diff --git a/.gitignore b/.gitignore index 3624d1226961..debad77ec2ad 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,7 @@ scalastyle-output.xml R-unit-tests.log R/unit-tests.out python/lib/pyspark.zip +lint-r-report.log # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index c0f81b57fe09..0240e81c45ea 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -28,6 +28,7 @@ spark-env.sh spark-env.cmd spark-env.sh.template log4j-defaults.properties +log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js @@ -80,5 +81,13 @@ local-1425081759269/* local-1426533911241/* local-1426633911242/* local-1430917381534/* +local-1430917381535_1 +local-1430917381535_2 DESCRIPTION NAMESPACE +test_support/* +.*Rd +help/* +html/* +INDEX +.lintr diff --git a/LICENSE b/LICENSE index d6b9ccf07d99..f9e412cade34 100644 --- a/LICENSE +++ b/LICENSE @@ -853,6 +853,52 @@ and Vis.js may be distributed under either license. +======================================================================== +For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +======================================================================== +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +======================================================================== +For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): +======================================================================== +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + ======================================================================== BSD-style licenses ======================================================================== @@ -861,7 +907,7 @@ The following components are provided under a BSD-style license. See project lin (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) - (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/) + (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) @@ -902,5 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) + (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/R/README.md b/R/README.md index a6970e39b55f..005f56da1670 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R #### Build Spark -Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ``` build/mvn -DskipTests -Psparkr package ``` @@ -52,7 +52,7 @@ The SparkR documentation (Rd files and HTML files) are not a part of the source SparkR comes with several sample programs in the `examples/src/main/r` directory. To run one of them, use `./bin/sparkR `. For example: - ./bin/sparkR examples/src/main/r/pi.R local[2] + ./bin/sparkR examples/src/main/r/dataframe.R You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): @@ -63,5 +63,5 @@ You can also run the unit-tests for SparkR by running (you need to install the [ The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run ``` export YARN_CONF_DIR=/etc/hadoop/conf -./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +./bin/spark-submit --master yarn examples/src/main/r/dataframe.R ``` diff --git a/R/create-docs.sh b/R/create-docs.sh index 4194172a2e11..6a4687b06ecb 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -23,14 +23,14 @@ # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html +set -o pipefail +set -e + # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" pushd $FWDIR -# Generate Rd file -Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))' - -# Install the package +# Install the package (this will also generate the Rd files) ./install-dev.sh # Now create HTML files diff --git a/R/install-dev.sh b/R/install-dev.sh index 55ed6f4be1a4..1edd551f8d24 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -26,11 +26,20 @@ # NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory # to load the SparkR package on the worker nodes. +set -o pipefail +set -e FWDIR="$(cd `dirname $0`; pwd)" LIB_DIR="$FWDIR/lib" mkdir -p $LIB_DIR -# Install R +pushd $FWDIR + +# Generate Rd files if devtools is installed +Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' + +# Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ + +popd diff --git a/R/log4j.properties b/R/log4j.properties index 701adb2a3da1..cce8d9152d32 100644 --- a/R/log4j.properties +++ b/R/log4j.properties @@ -19,7 +19,7 @@ log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=true -log4j.appender.file.file=R-unit-tests.log +log4j.appender.file.file=R/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n diff --git a/R/pkg/.lintr b/R/pkg/.lintr new file mode 100644 index 000000000000..038236fc149e --- /dev/null +++ b/R/pkg/.lintr @@ -0,0 +1,2 @@ +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 64ffdcffc9ca..7f857222452d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,12 +1,20 @@ # Imports from base R importFrom(methods, setGeneric, setMethod, setOldClass) -useDynLib(SparkR, stringHashCode) + +# Disable native libraries till we figure out how to package it +# See SPARKR-7839 +#useDynLib(SparkR, stringHashCode) # S3 methods exported export("sparkR.init") export("sparkR.stop") export("print.jobj") +# Job group lifecycle management methods +export("setJobGroup", + "clearJobGroup", + "cancelJobGroup") + exportClasses("DataFrame") exportMethods("arrange", @@ -16,9 +24,11 @@ exportMethods("arrange", "count", "describe", "distinct", + "dropna", "dtypes", "except", "explain", + "fillna", "filter", "first", "group_by", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a7fa32e291fb..60702824acb4 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -38,7 +38,7 @@ setClass("DataFrame", setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) { .Object@env <- new.env() .Object@env$isCached <- isCached - + .Object@sdf <- sdf .Object }) @@ -55,19 +55,19 @@ dataFrame <- function(sdf, isCached = FALSE) { ############################ DataFrame Methods ############################################## #' Print Schema of a DataFrame -#' +#' #' Prints out the schema in tree format -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname printSchema #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -78,19 +78,19 @@ setMethod("printSchema", }) #' Get schema object -#' +#' #' Returns the schema of this DataFrame as a structType object. -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname schema #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -100,9 +100,9 @@ setMethod("schema", }) #' Explain -#' +#' #' Print the logical and physical Catalyst plans to the console for debugging. -#' +#' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. #' @rdname explain @@ -110,9 +110,9 @@ setMethod("schema", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -139,9 +139,9 @@ setMethod("explain", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -162,15 +162,15 @@ setMethod("isLocal", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' showDF(df) #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) @@ -185,9 +185,9 @@ setMethod("showDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -200,19 +200,19 @@ setMethod("show", "DataFrame", }) #' DataTypes -#' +#' #' Return all column names and their data types as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname dtypes #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -224,19 +224,19 @@ setMethod("dtypes", }) #' Column names -#' +#' #' Return all column names as a list -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname columns #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' columns(df) #'} setMethod("columns", @@ -256,22 +256,22 @@ setMethod("names", }) #' Register Temporary Table -#' +#' #' Registers a DataFrame as a Temporary Table in the SQLContext -#' +#' #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table -#' +#' #' @rdname registerTempTable #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "json_df") -#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} setMethod("registerTempTable", signature(x = "DataFrame", tableName = "character"), @@ -293,9 +293,9 @@ setMethod("registerTempTable", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- read.df(sqlCtx, path, "parquet") -#' df2 <- read.df(sqlCtx, path2, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") +#' df2 <- read.df(sqlContext, path2, "parquet") #' registerTempTable(df, "table1") #' insertInto(df2, "table1", overwrite = TRUE) #'} @@ -306,19 +306,19 @@ setMethod("insertInto", }) #' Cache -#' +#' #' Persist with the default storage level (MEMORY_ONLY). -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname cache-methods #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -341,9 +341,9 @@ setMethod("cache", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -366,9 +366,9 @@ setMethod("persist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -391,16 +391,16 @@ setMethod("unpersist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", signature(x = "DataFrame", numPartitions = "numeric"), function(x, numPartitions) { sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) - dataFrame(sdf) + dataFrame(sdf) }) # toJSON @@ -415,9 +415,9 @@ setMethod("repartition", # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # newRDD <- toJSON(df) #} setMethod("toJSON", @@ -440,9 +440,9 @@ setMethod("toJSON", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsParquetFile(df, "/tmp/sparkr-tmp/") #'} setMethod("saveAsParquetFile", @@ -461,9 +461,9 @@ setMethod("saveAsParquetFile", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -486,10 +486,10 @@ setMethod("distinct", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) -#' collect(sample(df, FALSE, 0.5)) +#' df <- jsonFile(sqlContext, path) +#' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", @@ -513,19 +513,19 @@ setMethod("sample_frac", }) #' Count -#' +#' #' Returns the number of rows in a DataFrame -#' +#' #' @param x A SparkSQL DataFrame -#' +#' #' @rdname count #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' count(df) #' } setMethod("count", @@ -545,9 +545,9 @@ setMethod("count", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } @@ -568,21 +568,21 @@ setMethod("collect", }) #' Limit -#' +#' #' Limit the resulting DataFrame to the number of rows specified. -#' +#' #' @param x A SparkSQL DataFrame #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. -#' +#' #' @rdname limit #' @export #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -593,15 +593,15 @@ setMethod("limit", }) #' Take the first NUM rows of a DataFrame and return a the results as a data.frame -#' +#' #' @rdname take #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -613,8 +613,8 @@ setMethod("take", #' Head #' -#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, -#' then head() returns the first 6 rows in keeping with the current data.frame +#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL, +#' then head() returns the first 6 rows in keeping with the current data.frame #' convention in R. #' #' @param x A SparkSQL DataFrame @@ -626,9 +626,9 @@ setMethod("take", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' head(df) #' } setMethod("head", @@ -647,9 +647,9 @@ setMethod("head", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' first(df) #' } setMethod("first", @@ -659,19 +659,19 @@ setMethod("first", }) # toRDD() -# +# # Converts a Spark DataFrame to an RDD while preserving column names. -# +# # @param x A Spark DataFrame -# +# # @rdname DataFrame # @export # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # rdd <- toRDD(df) # } setMethod("toRDD", @@ -938,9 +938,9 @@ setMethod("select", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -964,9 +964,9 @@ setMethod("selectExpr", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -988,9 +988,9 @@ setMethod("withColumn", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' } @@ -1024,9 +1024,9 @@ setMethod("mutate", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1055,9 +1055,9 @@ setMethod("withColumnRenamed", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1095,9 +1095,9 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, "col1") #' arrange(df, asc(df$col1), desc(abs(df$col2))) @@ -1137,9 +1137,9 @@ setMethod("orderBy", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1167,7 +1167,7 @@ setMethod("where", #' #' @param x A Spark DataFrame #' @param y A Spark DataFrame -#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a +#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a #' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join #' @param joinType The type of join to perform. The following join types are available: #' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". @@ -1177,9 +1177,9 @@ setMethod("where", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1219,9 +1219,9 @@ setMethod("join", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1244,9 +1244,9 @@ setMethod("unionAll", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1269,9 +1269,9 @@ setMethod("intersect", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1303,23 +1303,22 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' -#' @rdname write.df +#' @rdname write.df #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") #' } setMethod("write.df", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1338,9 +1337,8 @@ setMethod("write.df", #' @aliases saveDF #' @export setMethod("saveDF", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ write.df(df, path, source, mode, ...) }) @@ -1371,9 +1369,9 @@ setMethod("saveDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", @@ -1381,8 +1379,8 @@ setMethod("saveAsTable", mode = 'character'), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1403,14 +1401,14 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame -#' @rdname describe +#' @rdname describe #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -1431,3 +1429,128 @@ setMethod("describe", sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) dataFrame(sdf) }) + +#' dropna +#' +#' Returns a new DataFrame omitting rows with null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param how "any" or "all". +#' if "any", drop a row if it contains any nulls. +#' if "all", drop a row only if all its values are null. +#' if minNonNulls is specified, how is ignored. +#' @param minNonNulls If specified, drop rows that have less than +#' minNonNulls non-null values. +#' This overwrites the how parameter. +#' @param cols Optional list of column names to consider. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dropna(df) +#' } +setMethod("dropna", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + how <- match.arg(how) + if (is.null(cols)) { + cols <- columns(x) + } + if (is.null(minNonNulls)) { + minNonNulls <- if (how == "any") { length(cols) } else { 1 } + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- callJMethod(naFunctions, "drop", + as.integer(minNonNulls), listToSeq(as.list(cols))) + dataFrame(sdf) + }) + +#' @aliases dropna +#' @export +setMethod("na.omit", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(x, how, minNonNulls, cols) + }) + +#' fillna +#' +#' Replace null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param value Value to replace null values with. +#' Should be an integer, numeric, character or named list. +#' If the value is a named list, then cols is ignored and +#' value must be a mapping from column name (character) to +#' replacement value. The replacement value must be an +#' integer, numeric or character. +#' @param cols optional list of column names to consider. +#' Columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' fillna(df, 1) +#' fillna(df, list("age" = 20, "name" = "unknown")) +#' } +setMethod("fillna", + signature(x = "DataFrame"), + function(x, value, cols = NULL) { + if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { + stop("value should be an integer, numeric, charactor or named list.") + } + + if (class(value) == "list") { + # Check column names in the named list + colNames <- names(value) + if (length(colNames) == 0 || !all(colNames != "")) { + stop("value should be an a named list with each name being a column name.") + } + + # Convert to the named list to an environment to be passed to JVM + valueMap <- new.env() + for (col in colNames) { + # Check each item in the named list is of valid type + v <- value[[col]] + if (!(class(v) %in% c("integer", "numeric", "character"))) { + stop("Each item in value should be an integer, numeric or charactor.") + } + valueMap[[col]] <- v + } + + # When value is a named list, caller is expected not to pass in cols + if (!is.null(cols)) { + warning("When value is a named list, cols is ignored!") + cols <- NULL + } + + value <- valueMap + } else if (is.integer(value)) { + # Cast an integer to a numeric + value <- as.numeric(value) + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- if (length(cols) == 0) { + callJMethod(naFunctions, "fill", value) + } else { + callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + } + dataFrame(sdf) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index d3a68fff780c..89511141d3ef 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -48,7 +48,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, # byte: The RDD stores data serialized in R. # string: The RDD stores data as strings. # row: The RDD stores the serialized rows of a DataFrame. - + # We use an environment to store mutable states inside an RDD object. # Note that R's call-by-value semantics makes modifying slots inside an # object (passed as an argument into a function, such as cache()) difficult: @@ -239,7 +239,7 @@ setMethod("cache", # @aliases persist,RDD-method setMethod("persist", signature(x = "RDD", newLevel = "character"), - function(x, newLevel) { + function(x, newLevel = "MEMORY_ONLY") { callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) x@env$isCached <- TRUE x @@ -363,7 +363,7 @@ setMethod("collectPartition", # @description # \code{collectAsMap} returns a named list as a map that contains all of the elements -# in a key-value pair RDD. +# in a key-value pair RDD. # @examples #\dontrun{ # sc <- sparkR.init() @@ -666,7 +666,7 @@ setMethod("minimum", # rdd <- parallelize(sc, 1:10) # sumRDD(rdd) # 55 #} -# @rdname sumRDD +# @rdname sumRDD # @aliases sumRDD,RDD setMethod("sumRDD", signature(x = "RDD"), @@ -1090,11 +1090,11 @@ setMethod("sortBy", # Return: # A list of the first N elements from the RDD in the specified order. # -takeOrderedElem <- function(x, num, ascending = TRUE) { +takeOrderedElem <- function(x, num, ascending = TRUE) { if (num <= 0L) { return(list()) } - + partitionFunc <- function(part) { if (num < length(part)) { # R limitation: order works only on primitive types! @@ -1152,7 +1152,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { # @aliases takeOrdered,RDD,RDD-method setMethod("takeOrdered", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num) }) @@ -1173,7 +1173,7 @@ setMethod("takeOrdered", # @aliases top,RDD,RDD-method setMethod("top", signature(x = "RDD", num = "integer"), - function(x, num) { + function(x, num) { takeOrderedElem(x, num, FALSE) }) @@ -1181,7 +1181,7 @@ setMethod("top", # # Aggregate the elements of each partition, and then the results for all the # partitions, using a given associative function and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param op An associative function for the folding operation. @@ -1207,7 +1207,7 @@ setMethod("fold", # # Aggregate the elements of each partition, and then the results for all the # partitions, using given combine functions and a neutral "zero value". -# +# # @param x An RDD. # @param zeroValue A neutral "zero value". # @param seqOp A function to aggregate the RDD elements. It may return a different @@ -1230,11 +1230,11 @@ setMethod("fold", # @aliases aggregateRDD,RDD,RDD-method setMethod("aggregateRDD", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), - function(x, zeroValue, seqOp, combOp) { + function(x, zeroValue, seqOp, combOp) { partitionFunc <- function(part) { Reduce(seqOp, part, zeroValue) } - + partitionList <- collect(lapplyPartition(x, partitionFunc), flatten = FALSE) Reduce(combOp, partitionList, zeroValue) @@ -1330,7 +1330,7 @@ setMethod("setName", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithUniqueId(rdd)) +# collect(zipWithUniqueId(rdd)) # # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #} # @rdname zipWithUniqueId @@ -1426,7 +1426,7 @@ setMethod("glom", partitionFunc <- function(part) { list(part) } - + lapplyPartition(x, partitionFunc) }) @@ -1498,16 +1498,16 @@ setMethod("zipRDD", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, TRUE) }) # Cartesian product of this RDD and another one. # -# Return the Cartesian product of this RDD and another one, -# that is, the RDD of all pairs of elements (a, b) where a +# Return the Cartesian product of this RDD and another one, +# that is, the RDD of all pairs of elements (a, b) where a # is in this and b is in other. -# +# # @param x An RDD. # @param other An RDD. # @return A new RDD which is the Cartesian product of these two RDDs. @@ -1515,7 +1515,7 @@ setMethod("zipRDD", #\dontrun{ # sc <- sparkR.init() # rdd <- parallelize(sc, 1:2) -# sortByKey(cartesian(rdd, rdd)) +# sortByKey(cartesian(rdd, rdd)) # # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) #} # @rdname cartesian @@ -1528,7 +1528,7 @@ setMethod("cartesian", # The jrdd's elements are of scala Tuple2 type. The serialized # flag here is used for the elements inside the tuples. rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - + mergePartitions(rdd, FALSE) }) @@ -1598,11 +1598,11 @@ setMethod("intersection", # Zips an RDD's partitions with one (or more) RDD(s). # Same as zipPartitions in Spark. -# +# # @param ... RDDs to be zipped. # @param func A function to transform zipped partitions. -# @return A new RDD by applying a function to the zipped partitions. -# Assumes that all the RDDs have the *same number of partitions*, but +# @return A new RDD by applying a function to the zipped partitions. +# Assumes that all the RDDs have the *same number of partitions*, but # does *not* require them to have the same number of elements in each partition. # @examples #\dontrun{ @@ -1610,7 +1610,7 @@ setMethod("intersection", # rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 # rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 # rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -# collect(zipPartitions(rdd1, rdd2, rdd3, +# collect(zipPartitions(rdd1, rdd2, rdd3, # func = function(x, y, z) { list(list(x, y, z))} )) # # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #} @@ -1627,7 +1627,7 @@ setMethod("zipPartitions", if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } - + rrdds <- lapply(rrdds, function(rdd) { mapPartitionsWithIndex(rdd, function(partIndex, part) { print(length(part)) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 531442e8459e..9a743a341153 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -69,7 +69,7 @@ infer_type <- function(x) { #' #' Converts an RDD to a DataFrame by infer the types. #' -#' @param sqlCtx A SQLContext +#' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return an DataFrame @@ -77,13 +77,13 @@ infer_type <- function(x) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -#' df <- createDataFrame(sqlCtx, rdd) +#' df <- createDataFrame(sqlContext, rdd) #' } # TODO(davies): support sampling and infer type from NA -createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { +createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD schema <- names(data) @@ -102,7 +102,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { }) } if (is.list(data)) { - sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) rdd <- parallelize(sc, data) } else if (inherits(data, "RDD")) { rdd <- data @@ -146,7 +146,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schema$jobj, sqlCtx) + srdd, schema$jobj, sqlContext) dataFrame(sdf) } @@ -161,7 +161,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) # df <- toDF(rdd) # } @@ -170,39 +170,39 @@ setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), function(x, ...) { - sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { get(".sparkRHivesc", envir = .sparkREnv) } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { get(".sparkRSQLsc", envir = .sparkREnv) } else { stop("no SQL context available") } - createDataFrame(sqlCtx, x, ...) + createDataFrame(sqlContext, x, ...) }) #' Create a DataFrame from a JSON file. #' -#' Loads a JSON file (one object per line), returning the result as a DataFrame +#' Loads a JSON file (one object per line), returning the result as a DataFrame #' It goes through the entire dataset once to determine the schema. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' } -jsonFile <- function(sqlCtx, path) { +jsonFile <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path path <- normalizePath(path) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlCtx, "jsonFile", path) + sdf <- callJMethod(sqlContext, "jsonFile", path) dataFrame(sdf) } @@ -211,7 +211,7 @@ jsonFile <- function(sqlCtx, path) { # # Loads an RDD storing one JSON object per string as a DataFrame. # -# @param sqlCtx SQLContext to use +# @param sqlContext SQLContext to use # @param rdd An RDD of JSON string # @param schema A StructType object to use as schema # @param samplingRatio The ratio of simpling used to infer the schema @@ -220,16 +220,16 @@ jsonFile <- function(sqlCtx, path) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- texFile(sc, "path/to/json") -# df <- jsonRDD(sqlCtx, rdd) +# df <- jsonRDD(sqlContext, rdd) # } # TODO: support schema -jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { +jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { rdd <- serializeToString(rdd) if (is.null(schema)) { - sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) dataFrame(sdf) } else { stop("not implemented") @@ -238,68 +238,67 @@ jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { #' Create a DataFrame from a Parquet file. -#' +#' #' Loads a Parquet file, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param ... Path(s) of parquet file(s) to read. #' @return DataFrame #' @export # TODO: Implement saveasParquetFile and write examples for both -parquetFile <- function(sqlCtx, ...) { +parquetFile <- function(sqlContext, ...) { # Allow the user to have a more flexible definiton of the text file path paths <- lapply(list(...), normalizePath) - sdf <- callJMethod(sqlCtx, "parquetFile", paths) + sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } #' SQL Query -#' +#' #' Executes a SQL query using Spark, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param sqlQuery A character vector containing the SQL query #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' new_df <- sql(sqlContext, "SELECT * FROM table") #' } -sql <- function(sqlCtx, sqlQuery) { - sdf <- callJMethod(sqlCtx, "sql", sqlQuery) - dataFrame(sdf) +sql <- function(sqlContext, sqlQuery) { + sdf <- callJMethod(sqlContext, "sql", sqlQuery) + dataFrame(sdf) } - #' Create a DataFrame from a SparkSQL Table -#' +#' #' Returns the specified Table as a DataFrame. The Table must have already been registered #' in the SQLContext. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a DataFrame. #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- table(sqlCtx, "table") +#' new_df <- table(sqlContext, "table") #' } -table <- function(sqlCtx, tableName) { - sdf <- callJMethod(sqlCtx, "table", tableName) - dataFrame(sdf) +table <- function(sqlContext, tableName) { + sdf <- callJMethod(sqlContext, "table", tableName) + dataFrame(sdf) } @@ -307,22 +306,22 @@ table <- function(sqlCtx, tableName) { #' #' Returns a DataFrame containing names of tables in the given database. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tables(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tables(sqlContext, "hive") #' } -tables <- function(sqlCtx, databaseName = NULL) { +tables <- function(sqlContext, databaseName = NULL) { jdf <- if (is.null(databaseName)) { - callJMethod(sqlCtx, "tables") + callJMethod(sqlContext, "tables") } else { - callJMethod(sqlCtx, "tables", databaseName) + callJMethod(sqlContext, "tables", databaseName) } dataFrame(jdf) } @@ -332,82 +331,82 @@ tables <- function(sqlCtx, databaseName = NULL) { #' #' Returns the names of tables in the given database as an array. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a list of table names #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tableNames(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tableNames(sqlContext, "hive") #' } -tableNames <- function(sqlCtx, databaseName = NULL) { +tableNames <- function(sqlContext, databaseName = NULL) { if (is.null(databaseName)) { - callJMethod(sqlCtx, "tableNames") + callJMethod(sqlContext, "tableNames") } else { - callJMethod(sqlCtx, "tableNames", databaseName) + callJMethod(sqlContext, "tableNames", databaseName) } } #' Cache Table -#' +#' #' Caches the specified table in-memory. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being cached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' cacheTable(sqlCtx, "table") +#' cacheTable(sqlContext, "table") #' } -cacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "cacheTable", tableName) +cacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "cacheTable", tableName) } #' Uncache Table -#' +#' #' Removes the specified table from the in-memory cache. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being uncached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' uncacheTable(sqlCtx, "table") +#' uncacheTable(sqlContext, "table") #' } -uncacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "uncacheTable", tableName) +uncacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "uncacheTable", tableName) } #' Clear Cache #' #' Removes all cached tables from the in-memory cache. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @examples #' \dontrun{ -#' clearCache(sqlCtx) +#' clearCache(sqlContext) #' } -clearCache <- function(sqlCtx) { - callJMethod(sqlCtx, "clearCache") +clearCache <- function(sqlContext) { + callJMethod(sqlContext, "clearCache") } #' Drop Temporary Table @@ -415,22 +414,22 @@ clearCache <- function(sqlCtx) { #' Drops the temporary table with the given table name in the catalog. #' If the table has been cached/persisted before, it's also unpersisted. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the SparkSQL table to be dropped. #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- read.df(sqlCtx, path, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") #' registerTempTable(df, "table") -#' dropTempTable(sqlCtx, "table") +#' dropTempTable(sqlContext, "table") #' } -dropTempTable <- function(sqlCtx, tableName) { +dropTempTable <- function(sqlContext, tableName) { if (class(tableName) != "character") { stop("tableName must be a string.") } - callJMethod(sqlCtx, "dropTempTable", tableName) + callJMethod(sqlContext, "dropTempTable", tableName) } #' Load an DataFrame @@ -441,7 +440,7 @@ dropTempTable <- function(sqlCtx, tableName) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path The path of files to load #' @param source the name of external data source #' @return DataFrame @@ -449,24 +448,35 @@ dropTempTable <- function(sqlCtx, tableName) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- read.df(sqlCtx, "path/to/file.json", source = "json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -read.df <- function(sqlCtx, path = NULL, source = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "load", source, options) + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + schema$jobj, options) + } else { + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + } dataFrame(sdf) } #' @aliases loadDF #' @export -loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { - read.df(sqlCtx, path, source, ...) +loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { + read.df(sqlContext, path, source, schema, ...) } #' Create an external table @@ -478,7 +488,7 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName A name of the table #' @param path The path of files to load #' @param source the name of external data source @@ -487,15 +497,15 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") #' } -createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { +createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) } diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 23dc38780716..2403925b267c 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -27,9 +27,9 @@ # @description Broadcast variables can be created using the broadcast # function from a \code{SparkContext}. # @rdname broadcast-class -# @seealso broadcast +# @seealso broadcast # -# @param id Id of the backing Spark broadcast variable +# @param id Id of the backing Spark broadcast variable # @export setClass("Broadcast", slots = list(id = "character")) @@ -68,7 +68,7 @@ setMethod("value", # variable on workers. Not intended for use outside the package. # # @rdname broadcast-internal -# @seealso broadcast, value +# @seealso broadcast, value # @param bcastId The id of broadcast variable to set # @param value The value to be set diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 1281c41213e3..78c7a3037ffa 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) { con } -launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { +determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { sparkSubmitBinName = "spark-submit" } else { sparkSubmitBinName = "spark-submit.cmd" } + sparkSubmitBinName +} + +generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + if (jars != "") { + jars <- paste("--jars", jars) + } + + if (packages != "") { + packages <- paste("--packages", packages) + } + combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") + combinedArgs +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) } else { sparkSubmitBin <- sparkSubmitBinName } - - if (jars != "") { - jars <- paste("--jars", jars) - } - - combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 80e92d3105a3..8e4b0f5bf1c4 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -210,6 +210,22 @@ setMethod("cast", } }) +#' Match a column with given values. +#' +#' @rdname column +#' @return a matched values as a result of comparing with given values. +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) +#' } +setMethod("%in%", + signature(x = "Column"), + function(x, table) { + table <- listToSeq(as.list(table)) + jc <- callJMethod(x@jc, "in", table) + return(column(jc)) + }) + #' Approx Count Distinct #' #' @rdname column diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 257b435607ce..d961bbc38368 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -18,7 +18,7 @@ # Utility functions to deserialize objects from Java. # Type mapping from Java to R -# +# # void -> NULL # Int -> integer # String -> character diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a23d3b217b2f..fad9d71158c5 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -20,7 +20,8 @@ # @rdname aggregateRDD # @seealso reduce # @export -setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) +setGeneric("aggregateRDD", + function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) # @rdname cache-methods # @export @@ -130,7 +131,7 @@ setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) -# @rdname sumRDD +# @rdname sumRDD # @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) @@ -219,7 +220,7 @@ setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD # @export -setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex @@ -364,7 +365,7 @@ setGeneric("subtract", # @rdname subtractByKey # @export -setGeneric("subtractByKey", +setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") }) @@ -396,6 +397,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname nafunctions +#' @export +setGeneric("dropna", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") + }) + +#' @rdname nafunctions +#' @export +setGeneric("na.omit", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") + }) + #' @rdname schema #' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) @@ -408,6 +423,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @export setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname nafunctions +#' @export +setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -482,11 +501,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { #' @rdname write.df #' @export -setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") }) +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) #' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) +setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) #' @rdname schema #' @export @@ -638,4 +657,3 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) - diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index b75848199757..8f1c68f7c4d2 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -136,4 +136,3 @@ createMethods <- function() { } createMethods() - diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index a8a25230b636..0838a7bb35e0 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -16,7 +16,7 @@ # # References to objects that exist on the JVM backend -# are maintained using the jobj. +# are maintained using the jobj. #' @include generics.R NULL diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7694652856da..0f1179e0aa51 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -329,7 +329,7 @@ setMethod("reduceByKey", convertEnvsToList(keys, vals) } locallyReduced <- lapplyPartition(x, reduceVals) - shuffled <- partitionBy(locallyReduced, numPartitions) + shuffled <- partitionBy(locallyReduced, numToInt(numPartitions)) lapplyPartition(shuffled, reduceVals) }) @@ -436,7 +436,7 @@ setMethod("combineByKey", convertEnvsToList(keys, combiners) } locallyCombined <- lapplyPartition(x, combineLocally) - shuffled <- partitionBy(locallyCombined, numPartitions) + shuffled <- partitionBy(locallyCombined, numToInt(numPartitions)) mergeAfterShuffle <- function(part) { combiners <- new.env() keys <- new.env() @@ -560,8 +560,8 @@ setMethod("join", # Left outer join two RDDs # # @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -597,8 +597,8 @@ setMethod("leftOuterJoin", # Right outer join two RDDs # # @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -634,8 +634,8 @@ setMethod("rightOuterJoin", # Full outer join two RDDs # # @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. +# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +# the form list(K, V). The key types of the two RDDs should be the same. # # @param x An RDD to be joined. Should be an RDD where each element is # list(K, V). @@ -784,7 +784,7 @@ setMethod("sortByKey", newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) - + # Subtract a pair RDD with another pair RDD. # # Return an RDD with the pairs from x whose keys are not in other. @@ -820,7 +820,7 @@ setMethod("subtractByKey", }) # Return a subset of this RDD sampled by key. -# +# # @description # \code{sampleByKey} Create a sample of this RDD using variable sampling rates # for different keys as specified by fractions, a key to sampling rate map. diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index e442119086b1..15e2bdbd55d7 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -20,7 +20,7 @@ #' structType #' -#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' Create a structType object that contains the metadata for a DataFrame. Intended for #' use with createDataFrame and toDF. #' #' @param x a structField object (created with the field() function) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index c53d0a961016..78535eff0d2f 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -37,6 +37,14 @@ writeObject <- function(con, object, writeType = TRUE) { # passing in vectors as arrays and instead require arrays to be passed # as lists. type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") + # Checking types is needed here, since ‘is.na’ only handles atomic vectors, + # lists and pairlists + if (type %in% c("integer", "character", "logical", "double", "numeric")) { + if (is.na(object)) { + object <- NULL + type <- "NULL" + } + } if (writeType) { writeType(con, type) } @@ -160,6 +168,14 @@ writeList <- function(con, arr) { } } +# Used to pass arrays where the elements can be of different types +writeGenericList <- function(con, list) { + writeInt(con, length(list)) + for (elem in list) { + writeObject(con, elem) + } +} + # Used to pass in hash maps required on Java side. writeEnv <- function(con, env) { len <- length(env) @@ -168,7 +184,7 @@ writeEnv <- function(con, env) { if (len > 0) { writeList(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeList(con, as.list(vals)) + writeGenericList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index bc82df01f0ff..048eb8ed541e 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -43,7 +43,7 @@ sparkR.stop <- function() { callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) } - + if (exists(".backendLaunched", envir = env)) { callJStatic("SparkRHandler", "stopBackend") } @@ -81,6 +81,7 @@ sparkR.stop <- function() { #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. #' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -100,14 +101,16 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "") { + sparkRLibDir = "", + sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { - cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") + cat(paste("Re-using existing Spark Context.", + "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "512m") + sparkMem <- Sys.getenv("SPARK_MEM", "1024m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows @@ -129,7 +132,8 @@ sparkR.init <- function( args = path, sparkHome = sparkHome, jars = jars, - sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), + packages = sparkPackages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -174,17 +178,19 @@ sparkR.init <- function( for (varname in names(sparkEnvir)) { sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] } - + sparkExecutorEnvMap <- new.env() if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { - sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) + sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- + paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } for (varname in names(sparkExecutorEnv)) { sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) + localJarPaths <- sapply(nonEmptyJars, + function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs # Seconds resolution is good enough for this purpose, so use ints @@ -214,7 +220,7 @@ sparkR.init <- function( #' Initialize a new SQLContext. #' -#' This function creates a SparkContext from an existing JavaSparkContext and +#' This function creates a SparkContext from an existing JavaSparkContext and #' then uses it to initialize a new SQLContext #' #' @param jsc The existing JavaSparkContext created with SparkR.init() @@ -222,19 +228,26 @@ sparkR.init <- function( #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #'} -sparkRSQL.init <- function(jsc) { +sparkRSQL.init <- function(jsc = NULL) { if (exists(".sparkRSQLsc", envir = .sparkREnv)) { return(get(".sparkRSQLsc", envir = .sparkREnv)) } - sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - jsc) - assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) - sqlCtx + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + sc) + assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) + sqlContext } #' Initialize a new HiveContext. @@ -246,15 +259,22 @@ sparkRSQL.init <- function(jsc) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRHive.init(sc) +#' sqlContext <- sparkRHive.init(sc) #'} -sparkRHive.init <- function(jsc) { +sparkRHive.init <- function(jsc = NULL) { if (exists(".sparkRHivesc", envir = .sparkREnv)) { return(get(".sparkRHivesc", envir = .sparkREnv)) } - ssc <- callJMethod(jsc, "sc") + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) }, error = function(err) { @@ -264,3 +284,47 @@ sparkRHive.init <- function(jsc) { assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) hiveCtx } + +#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a +#' different value or cleared. +#' +#' @param sc existing spark context +#' @param groupid the ID to be assigned to job groups +#' @param description description for the the job group ID +#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE) +#'} + +setJobGroup <- function(sc, groupId, description, interruptOnCancel) { + callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) +} + +#' Clear current job group ID and its description +#' +#' @param sc existing spark context +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' clearJobGroup(sc) +#'} + +clearJobGroup <- function(sc) { + callJMethod(sc, "clearJobGroup") +} + +#' Cancel active jobs for the specified group +#' +#' @param sc existing spark context +#' @param groupId the ID of job group to be cancelled +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' cancelJobGroup(sc, "myJobGroup") +#'} + +cancelJobGroup <- function(sc, groupId) { + callJMethod(sc, "cancelJobGroup", groupId) +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 0e7b7bd5a5b3..ea629a64f715 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -122,13 +122,49 @@ hashCode <- function(key) { intBits <- packBits(rawToBits(rawVec), "integer") as.integer(bitwXor(intBits[2], intBits[1])) } else if (class(key) == "character") { - .Call("stringHashCode", key) + # TODO: SPARK-7839 means we might not have the native library available + if (is.loaded("stringHashCode")) { + .Call("stringHashCode", key) + } else { + n <- nchar(key) + if (n == 0) { + 0L + } else { + asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) }) + hashC <- 0 + for (k in 1:length(asciiVals)) { + hashC <- mult31AndAdd(hashC, asciiVals[k]) + } + as.integer(hashC) + } + } } else { warning(paste("Could not hash object, returning 0", sep = "")) as.integer(0) } } +# Helper function used to wrap a 'numeric' value to integer bounds. +# Useful for implementing C-like integer arithmetic +wrapInt <- function(value) { + if (value > .Machine$integer.max) { + value <- value - 2 * .Machine$integer.max - 2 + } else if (value < -1 * .Machine$integer.max) { + value <- 2 * .Machine$integer.max + value + 2 + } + value +} + +# Multiply `val` by 31 and add `addVal` to the result. Ensures that +# integer-overflows are handled at every step. +mult31AndAdd <- function(val, addVal) { + vec <- c(bitwShiftL(val, c(4,3,2,1,0)), addVal) + Reduce(function(a, b) { + wrapInt(as.numeric(a) + as.numeric(b)) + }, + vec) +} + # Create a new RDD with serializedMode == "byte". # Return itself if already in "byte" format. serializeToBytes <- function(rdd) { @@ -298,18 +334,21 @@ getStorageLevel <- function(newLevel = c("DISK_ONLY", "MEMORY_ONLY_SER_2", "OFF_HEAP")) { match.arg(newLevel) + storageLevelClass <- "org.apache.spark.storage.StorageLevel" storageLevel <- switch(newLevel, - "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"), - "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"), - "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"), - "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"), - "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"), - "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"), - "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"), - "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"), - "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"), - "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"), - "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP")) + "DISK_ONLY" = callJStatic(storageLevelClass, "DISK_ONLY"), + "DISK_ONLY_2" = callJStatic(storageLevelClass, "DISK_ONLY_2"), + "MEMORY_AND_DISK" = callJStatic(storageLevelClass, "MEMORY_AND_DISK"), + "MEMORY_AND_DISK_2" = callJStatic(storageLevelClass, "MEMORY_AND_DISK_2"), + "MEMORY_AND_DISK_SER" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER"), + "MEMORY_AND_DISK_SER_2" = callJStatic(storageLevelClass, + "MEMORY_AND_DISK_SER_2"), + "MEMORY_ONLY" = callJStatic(storageLevelClass, "MEMORY_ONLY"), + "MEMORY_ONLY_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_2"), + "MEMORY_ONLY_SER" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER"), + "MEMORY_ONLY_SER_2" = callJStatic(storageLevelClass, "MEMORY_ONLY_SER_2"), + "OFF_HEAP" = callJStatic(storageLevelClass, "OFF_HEAP")) } # Utility function for functions where an argument needs to be integer but we want to allow @@ -332,21 +371,21 @@ listToSeq <- function(l) { } # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a -# user defined function (UDF), and to examine variables in the UDF to decide +# user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. # param # node The current AST node in the traversal. # oldEnv The original function environment. # defVars An Accumulator of variables names defined in the function's calling environment, # including function argument and local variable names. -# checkedFunc An environment of function objects examined during cleanClosure. It can +# checkedFunc An environment of function objects examined during cleanClosure. It can # be considered as a "name"-to-"list of functions" mapping. # newEnv A new function environment to store necessary function dependencies, an output argument. processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { nodeLen <- length(node) - + if (nodeLen > 1 && typeof(node) == "language") { - # Recursive case: current AST node is an internal node, check for its children. + # Recursive case: current AST node is an internal node, check for its children. if (length(node[[1]]) > 1) { for (i in 1:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) @@ -357,7 +396,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 2:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "<-" || nodeChar == "=" || + } else if (nodeChar == "<-" || nodeChar == "=" || nodeChar == "<<-") { # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { @@ -386,21 +425,21 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } } - } else if (nodeLen == 1 && + } else if (nodeLen == 1 && (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. nodeChar <- as.character(node) if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) - # Search in function environment, and function's enclosing environments + # Search in function environment, and function's enclosing environments # up to global environment. There is no need to look into package environments - # above the global or namespace environment that is not SparkR below the global, + # above the global or namespace environment that is not SparkR below the global, # as they are assumed to be loaded on workers. while (!identical(func.env, topEnv)) { # Namespaces other than "SparkR" will not be searched. - if (!isNamespace(func.env) || - (getNamespaceName(func.env) == "SparkR" && + if (!isNamespace(func.env) || + (getNamespaceName(func.env) == "SparkR" && !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. @@ -408,7 +447,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) if (is.function(obj)) { # If the node is a function call. - funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, ifnotfound = list(list(NULL)))[[1]] found <- sapply(funcList, function(func) { ifelse(identical(func, obj), TRUE, FALSE) @@ -417,7 +456,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { break } # Function has not been examined, record it and recursively clean its closure. - assign(nodeChar, + assign(nodeChar, if (is.null(funcList[[1]])) { list(obj) } else { @@ -430,7 +469,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { break } } - + # Continue to search in enclosure. func.env <- parent.env(func.env) } @@ -438,8 +477,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { } } -# Utility function to get user defined function (UDF) dependencies (closure). -# More specifically, this function captures the values of free variables defined +# Utility function to get user defined function (UDF) dependencies (closure). +# More specifically, this function captures the values of free variables defined # outside a UDF, and stores them in the function's environment. # param # func A function whose closure needs to be captured. @@ -452,7 +491,7 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { newEnv <- new.env(parent = .GlobalEnv) func.body <- body(func) oldEnv <- environment(func) - # defVars is an Accumulator of variables names defined in the function's calling + # defVars is an Accumulator of variables names defined in the function's calling # environment. First, function's arguments are added to defVars. defVars <- initAccumulator() argNames <- names(as.list(args(func))) @@ -473,15 +512,15 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { # return value # A list of two result RDDs. appendPartitionLengths <- function(x, other) { - if (getSerializedMode(x) != getSerializedMode(other) || + if (getSerializedMode(x) != getSerializedMode(other) || getSerializedMode(x) == "byte") { # Append the number of elements in each partition to that partition so that we can later # know the boundary of elements from x and other. # - # Note that this appending also serves the purpose of reserialization, because even if + # Note that this appending also serves the purpose of reserialization, because even if # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded # as a single byte array. For example, partitions of an RDD generated from partitionBy() - # may be encoded as multiple byte arrays. + # may be encoded as multiple byte arrays. appendLength <- function(part) { len <- length(part) part[[len + 1]] <- len + 1 @@ -508,23 +547,25 @@ mergePartitions <- function(rdd, zip) { lengthOfValues <- part[[len]] lengthOfKeys <- part[[len - lengthOfValues]] stopifnot(len == lengthOfKeys + lengthOfValues) - - # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + + # For zip operation, check if corresponding partitions + # of both RDDs have the same number of elements. if (zip && lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + stop(paste("Can only zip RDDs with same number of elements", + "in each pair of corresponding partitions.")) } - + if (lengthOfKeys > 1) { keys <- part[1 : (lengthOfKeys - 1)] } else { keys <- list() } if (lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[(lengthOfKeys + 1) : (len - 1)] } else { values <- list() } - + if (!zip) { return(mergeCompactLists(keys, values)) } @@ -542,6 +583,6 @@ mergePartitions <- function(rdd, zip) { part } } - + PipelinedRDD(rdd, partitionFunc) } diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R index 80d796d46794..301feade65fa 100644 --- a/R/pkg/R/zzz.R +++ b/R/pkg/R/zzz.R @@ -18,4 +18,3 @@ .onLoad <- function(libname, pkgname) { sparkR.onLoad(libname, pkgname) } - diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 33478d9e2999..7189f1a26093 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -24,10 +24,24 @@ old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = "")) + sc <- SparkR::sparkR.init() assign("sc", sc, envir=.GlobalEnv) - sqlCtx <- SparkR::sparkRSQL.init(sc) - assign("sqlCtx", sqlCtx, envir=.GlobalEnv) - cat("\n Welcome to SparkR!") - cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") + sqlContext <- SparkR::sparkRSQL.init(sc) + sparkVer <- SparkR:::callJMethod(sc, "version") + assign("sqlContext", sqlContext, envir=.GlobalEnv) + cat("\n Welcome to") + cat("\n") + cat(" ____ __", "\n") + cat(" / __/__ ___ _____/ /__", "\n") + cat(" _\\ \\/ _ \\/ _ `/ __/ '_/", "\n") + cat(" /___/ .__/\\_,_/_/ /_/\\_\\") + if (nchar(sparkVer) == 0) { + cat("\n") + } else { + cat(" version ", sparkVer, "\n") + } + cat(" /_/", "\n") + cat("\n") + + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") } diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar new file mode 100644 index 000000000000..1d5c2af631aa Binary files /dev/null and b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar differ diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/jarTest.R new file mode 100644 index 000000000000..d68bb20950b0 --- /dev/null +++ b/R/pkg/inst/tests/jarTest.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +library(SparkR) + +sc <- sparkR.init() + +helloTest <- SparkR:::callJStatic("sparkR.test.hello", + "helloWorld", + "Dave") + +basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction", + "addStuff", + 2L, + 2L) + +sparkR.stop() +output <- c(helloTest, basicFunction) +writeLines(output) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ca4218f3819f..ccaea18ecab2 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -59,15 +59,15 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", wordCount <- lapply(words, function(word) { list(word, 1L) }) counts <- reduceByKey(wordCount, "+", 2L) - + saveAsObjectFile(counts, fileName2) counts <- objectFile(sc, fileName2) - + output <- collect(counts) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName1) unlink(fileName2, recursive = TRUE) }) @@ -82,9 +82,8 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) }) - diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 6785a7bdae8c..3be8c65a6c1a 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -30,7 +30,7 @@ mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { actual <- collect(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) - + fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") rdd<- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) @@ -52,14 +52,14 @@ test_that("union on two RDDs", { test_that("cogroup on two RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) - expect_equal(actual, + expect_equal(actual, list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) - + rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) - cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) + cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) actual <- collect(cogroup.rdd) expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) @@ -71,31 +71,31 @@ test_that("zipPartitions() on RDDs", { rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 - actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { list(list(x, y, z))} )) expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - + mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName, 1) - actual <- collect(zipPartitions(rdd, rdd, + actual <- collect(zipPartitions(rdd, rdd, func = function(x, y) { list(paste(x, y, sep = "\n")) })) expected <- list(paste(mockFile, mockFile, sep = "\n")) expect_equal(actual, expected) - + rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipPartitions(rdd1, rdd, + actual <- collect(zipPartitions(rdd1, rdd, func = function(x, y) { list(x + nchar(y)) })) expected <- list(0:1 + nchar(mockFile)) expect_equal(actual, expected) - + rdd <- map(rdd, function(x) { x }) - actual <- collect(zipPartitions(rdd, rdd1, + actual <- collect(zipPartitions(rdd, rdd1, func = function(x, y) { list(y + nchar(x)) })) expect_equal(actual, expected) - + unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R new file mode 100644 index 000000000000..30b05c1a2afc --- /dev/null +++ b/R/pkg/inst/tests/test_client.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in client.R") + +test_that("adding spark-testing-base as a package works", { + args <- generateSparkSubmitArgs("", "", "", "", + "holdenk:spark-testing-base:1.3.0_0.0.5") + expect_equal(gsub("[[:space:]]", "", args), + gsub("[[:space:]]", "", + "--packages holdenk:spark-testing-base:1.3.0_0.0.5")) +}) + +test_that("no package specified doesn't add packages flag", { + args <- generateSparkSubmitArgs("", "", "", "", "") + expect_equal(gsub("[[:space:]]", "", args), + "") +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index e4aab37436a7..513bbc8e6205 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -48,3 +48,10 @@ test_that("rdd GC across sparkR.stop", { count(rdd3) count(rdd4) }) + +test_that("job group functions can be called", { + sc <- sparkR.init() + setJobGroup(sc, "groupId", "job description", TRUE) + cancelJobGroup(sc, "groupId") + clearJobGroup(sc) +}) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R new file mode 100644 index 000000000000..cc1faeabffe3 --- /dev/null +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +context("include an external JAR in SparkContext") + +runScript <- function() { + sparkHome <- Sys.getenv("SPARK_HOME") + sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" + jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + submitPath <- file.path(sparkHome, "bin/spark-submit") + res <- system2(command = submitPath, + args = c(jarPath, scriptPath), + stdout = TRUE) + tail(res, 2) +} + +test_that("sparkJars tag in SparkContext", { + testOutput <- runScript() + helloTest <- testOutput[1] + expect_equal(helloTest, "Hello, Dave") + basicFunction <- testOutput[2] + expect_equal(basicFunction, "4") +}) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db3..2552127cc547 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 03207353c31c..b79692873cec 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { @@ -477,7 +477,7 @@ test_that("cartesian() on RDDs", { list(1, 1), list(1, 2), list(1, 3), list(2, 1), list(2, 2), list(2, 3), list(3, 1), list(3, 2), list(3, 3))) - + # test case where one RDD is empty emptyRdd <- parallelize(sc, list()) actual <- collect(cartesian(rdd, emptyRdd)) @@ -486,7 +486,7 @@ test_that("cartesian() on RDDs", { mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName) actual <- collect(cartesian(rdd, rdd)) expected <- list( @@ -495,7 +495,7 @@ test_that("cartesian() on RDDs", { list("Spark is pretty.", "Spark is pretty."), list("Spark is pretty.", "Spark is awesome.")) expect_equal(sortKeyValueList(actual), expected) - + rdd1 <- parallelize(sc, 0:1) actual <- collect(cartesian(rdd1, rdd)) expect_equal(sortKeyValueList(actual), @@ -504,11 +504,11 @@ test_that("cartesian() on RDDs", { list(0, "Spark is awesome."), list(1, "Spark is pretty."), list(1, "Spark is awesome."))) - + rdd1 <- map(rdd, function(x) { x }) actual <- collect(cartesian(rdd, rdd1)) expect_equal(sortKeyValueList(actual), expected) - + unlink(fileName) }) @@ -669,13 +669,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3))) rdd2 <- parallelize(sc, list(list(1,1), list(2,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) + expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), + list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a",1), list("b",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) - expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) + expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), + list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -683,13 +685,15 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd2 <- parallelize(sc, list(list(3,3), list(4,4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) + sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), + list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a",1), list("b",2))) rdd2 <- parallelize(sc, list(list("c",3), list("d",4))) actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), - sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) + sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), + list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) test_that("sortByKey() on pairwise RDDs", { @@ -760,7 +764,7 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { - rdd <- parallelize(sc, list(1:10)) + rdd <- parallelize(sc, list(1:10)) expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R index d7dedda553c5..adf0b91d25fe 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/test_shuffle.R @@ -106,39 +106,39 @@ test_that("aggregateByKey", { zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + actual <- collect(aggregatedRDD) - + expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test aggregateByKey for string keys rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4))) - + zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } - aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) + aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) actual <- collect(aggregatedRDD) - + expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) -test_that("foldByKey", { +test_that("foldByKey", { # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) # test foldByKey for double keys folded <- foldByKey(doubleRdd, 0, "+", 2L) - + actual <- collect(folded) expected <- list(list(1.5, 199), list(2.5, 101)) @@ -146,15 +146,15 @@ test_that("foldByKey", { # test foldByKey for string keys stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200)) - + stringKeyRDD <- parallelize(sc, stringKeyPairs) folded <- foldByKey(stringKeyRDD, 0, "+", 2L) - + actual <- collect(folded) - + expected <- list(list("b", 101), list("a", 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - + # test foldByKey for empty pair RDD rdd <- parallelize(sc, list()) folded <- foldByKey(rdd, 0, "+", 2L) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 3e5658eb5b24..b0ea38854304 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -19,11 +19,19 @@ library(testthat) context("SparkSQL functions") +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() -sqlCtx <- sparkRSQL.init(sc) +sqlContext <- sparkRSQL.init(sc) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", @@ -32,6 +40,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") writeLines(mockLines, jsonPath) +# For test nafunctions, like dropna(), fillna(),... +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -43,9 +60,10 @@ test_that("infer types", { list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(a = 1L, b = "2")), - structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE))) + testStruct <- infer_type(list(a = 1L, b = "2")) + expect_equal(class(testStruct), "structType") + checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) + checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), @@ -55,83 +73,120 @@ test_that("infer types", { test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) - + testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - df <- createDataFrame(sqlCtx, rdd) - expect_true(inherits(df, "DataFrame")) + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlCtx, rdd, schema) - expect_true(inherits(df, "DataFrame")) + df <- createDataFrame(sqlContext, rdd, schema) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + df <- createDataFrame(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) +test_that("convert NAs to null type in DataFrames", { + rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4L) + + l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1L) + expect_true(is.na(collect(df)[2, "y"])) + + rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) + expect_true(is.na(collect(df)[2, "a"])) + expect_equal(collect(df)[2, "b"], 4) + + l <- data.frame(x = 1, y = c(1, NA_real_, 3)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(df)[2, "x"], 1) + expect_true(is.na(collect(df)[2, "y"])) + + l <- list("a", "b", NA, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list("a", "b", NA_character_, "d") + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], "d") + + l <- list(TRUE, FALSE, NA, TRUE) + df <- createDataFrame(sqlContext, l) + expect_true(is.na(collect(df)[3, "_1"])) + expect_equal(collect(df)[4, "_1"], TRUE) +}) + test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) test_that("create DataFrame from list or data.frame", { l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlCtx, l, c("a", "b")) + df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) l <- list(list(a=1, b=2), list(a=3, b=4)) - df <- createDataFrame(sqlCtx, l) + df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) - df <- createDataFrame(sqlCtx, ldf) + df <- createDataFrame(sqlContext, ldf) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(count(df), 3) @@ -142,7 +197,7 @@ test_that("create DataFrame from list or data.frame", { test_that("create DataFrame with different data types", { l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlCtx, list(l)) + df <- createDataFrame(sqlContext, list(l)) expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), c("d", "string"), c("e", "date"), c("f", "timestamp"))) expect_equal(count(df), 1) @@ -154,7 +209,7 @@ test_that("create DataFrame with different data types", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) -# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) # expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), # c("c", "map"), c("d", "struct"))) # expect_equal(count(df), 1) @@ -163,102 +218,102 @@ test_that("create DataFrame with different data types", { #}) test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + df <- jsonFile(sqlContext, jsonPath) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) - df <- jsonRDD(sqlCtx, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_equal(count(rdd), 3) + df <- jsonRDD(sqlContext, rdd) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlCtx, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + df <- jsonRDD(sqlContext, rdd2) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - cacheTable(sqlCtx, "table1") - uncacheTable(sqlCtx, "table1") - clearCache(sqlCtx) - dropTempTable(sqlCtx, "table1") + cacheTable(sqlContext, "table1") + uncacheTable(sqlContext, "table1") + clearCache(sqlContext) + dropTempTable(sqlContext, "table1") }) test_that("test tableNames and tables", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlCtx)) == 1) - df <- tables(sqlCtx) - expect_true(count(df) == 1) - dropTempTable(sqlCtx, "table1") + expect_equal(length(tableNames(sqlContext)), 1) + df <- tables(sqlContext) + expect_equal(count(df), 1) + dropTempTable(sqlContext, "table1") }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) - dropTempTable(sqlCtx, "table1") + newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) + dropTempTable(sqlContext, "table1") }) test_that("insertInto() on a registered table", { - df <- read.df(sqlCtx, jsonPath, "json") + df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(sqlCtx, parquetPath, "parquet") + dfParquet <- read.df(sqlContext, parquetPath, "parquet") lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) - df2 <- read.df(sqlCtx, jsonPath2, "json") + df2 <- read.df(sqlContext, jsonPath2, "json") write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(sqlCtx, parquetPath2, "parquet") + dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlCtx, "select * from table1")) == 5) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") - dropTempTable(sqlCtx, "table1") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") + dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlCtx, "select * from table1")) == 2) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") - dropTempTable(sqlCtx, "table1") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") + dropTempTable(sqlContext, "table1") }) test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - tabledf <- table(sqlCtx, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) - dropTempTable(sqlCtx, "table1") + tabledf <- table(sqlContext, "table1") + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) + dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -274,70 +329,70 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlCtx, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + df <- jsonFile(sqlContext, jsonPath) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) -test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { - df <- jsonFile(sqlCtx, jsonPath) +test_that("multiple pipeline transformations result in an RDD with the correct values", { + df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -346,15 +401,15 @@ test_that("multiple pipeline transformations starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - expect_true(inherits(second, "RDD")) - expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -373,38 +428,38 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) }) test_that("distinct() on DataFrames", { @@ -415,17 +470,17 @@ test_that("distinct() on DataFrames", { jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlCtx, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) }) test_that("sample on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) + expect_is(sampled, "DataFrame") sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) @@ -435,16 +490,16 @@ test_that("sample on a DataFrame", { }) test_that("select operators", { - df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age @@ -461,48 +516,61 @@ test_that("select operators", { }) test_that("select with column", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) }) test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) }) test_that("column calculation", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("read.df() from json file", { - df <- read.df(sqlCtx, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + df <- read.df(sqlContext, jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Check if we can apply a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_is(df1, "DataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Run the same with loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_is(df2, "DataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) test_that("write.df() as parquet file", { - df <- read.df(sqlCtx, jsonPath, "json") + df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") - df2 <- read.df(sqlCtx, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + df2 <- read.df(sqlContext, parquetPath, "parquet") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { @@ -512,17 +580,17 @@ test_that("test HiveContext", { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -539,7 +607,7 @@ test_that("column functions", { c3 <- lower(c) + upper(c) + first(c) + last(c) c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) + c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) @@ -553,7 +621,7 @@ test_that("column binary mathfunctions", { "{\"a\":4, \"b\":8}") jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlCtx, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) @@ -565,7 +633,7 @@ test_that("column binary mathfunctions", { }) test_that("string operators", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ -573,71 +641,81 @@ test_that("string operators", { }) test_that("group by", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, "name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) + + # test suites for %in% + filtered3 <- filter(df, "age in (19)") + expect_equal(count(filtered3), 1) + filtered4 <- filter(df, "age in (19, 30)") + expect_equal(count(filtered4), 2) + filtered5 <- where(df, df$age %in% c(19)) + expect_equal(count(filtered5), 1) + filtered6 <- where(df, df$age %in% c(19, 30)) + expect_equal(count(filtered6), 2) }) test_that("join() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -645,125 +723,232 @@ test_that("join() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlCtx, jsonPath2) + df2 <- jsonFile(sqlContext, jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) test_that("showDF()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) - expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") + expected <- paste("+----+-------+\n", + "| age| name|\n", + "+----+-------+\n", + "|null|Michael|\n", + "| 30| Andy|\n", + "| 19| Justin|\n", + "+----+-------+\n", sep="") + expect_output(s , expected) }) test_that("isLocal()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPath2) - df2 <- read.df(sqlCtx, jsonPath2, "json") + df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("mutate() and rename()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("write.df() on DataFrame and works with parquetFile", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath) - expect_true(inherits(parquetDF, "DataFrame")) + parquetDF <- parquetFile(sqlContext, parquetPath) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) test_that("parquetFile works with multiple input paths", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df)*2) }) test_that("describe() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) stats <- describe(df, "age") - expect_true(collect(stats)[1, "summary"] == "count") - expect_true(collect(stats)[2, "age"] == 24.5) - expect_true(collect(stats)[3, "age"] == 5.5) + expect_equal(collect(stats)[1, "summary"], "count") + expect_equal(collect(stats)[2, "age"], "24.5") + expect_equal(collect(stats)[3, "age"], "5.5") stats <- describe(df) - expect_true(collect(stats)[4, "name"] == "Andy") - expect_true(collect(stats)[5, "age"] == 30.0) + expect_equal(collect(stats)[4, "name"], "Andy") + expect_equal(collect(stats)[5, "age"], "30") +}) + +test_that("dropna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) +}) + +test_that("fillna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_identical(expected, actual) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_identical(expected, actual) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_identical(expected, actual) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_identical(expected, actual) }) unlink(parquetPath) unlink(jsonPath) +unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index 7f4c7c315d78..c2c724cdc762 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,9 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) - diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 6b87b4b3e0b0..58318dfef71a 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -58,7 +58,7 @@ test_that("textFile() word count works as expected", { expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), list("Spark", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) - + unlink(fileName) }) @@ -115,13 +115,13 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { saveAsTextFile(counts, fileName2) rdd <- textFile(sc, fileName2) - + output <- collect(rdd) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expectedStr <- lapply(expected, function(x) { toString(x) }) expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr)) - + unlink(fileName1) unlink(fileName2) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) @@ -159,4 +159,3 @@ test_that("Pipelined operations on RDDs created using textFile", { unlink(fileName) }) - diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 539e3a3c19df..aa0d2a66b908 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -43,13 +43,13 @@ test_that("serializeToBytes on RDD", { mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") - + expect_equal(getSerializedMode(ser.rdd), "byte") + unlink(fileName) }) @@ -64,7 +64,7 @@ test_that("cleanClosure on R functions", { expect_equal(actual, y) actual <- get("g", envir = env, inherits = FALSE) expect_equal(actual, g) - + # Test for nested enclosures and package variables. env2 <- new.env() funcEnv <- new.env(parent = env2) @@ -106,7 +106,7 @@ test_that("cleanClosure on R functions", { expect_equal(length(ls(env)), 1) actual <- get("y", envir = env, inherits = FALSE) expect_equal(actual, y) - + # Test for function (and variable) definitions. f <- function(x) { g <- function(y) { y * 2 } @@ -115,7 +115,7 @@ test_that("cleanClosure on R functions", { newF <- cleanClosure(f) env <- environment(newF) expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. - + # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) @@ -128,7 +128,7 @@ test_that("cleanClosure on R functions", { actual <- collect(lapply(rdd, f)) expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) expect_equal(actual, expected) - + # Test for broadcast variables. a <- matrix(nrow=10, ncol=10, data=rnorm(100)) aBroadcast <- broadcast(sc, a) diff --git a/R/pkg/src/Makefile b/R/pkg/src-native/Makefile similarity index 100% rename from R/pkg/src/Makefile rename to R/pkg/src-native/Makefile diff --git a/R/pkg/src/Makefile.win b/R/pkg/src-native/Makefile.win similarity index 100% rename from R/pkg/src/Makefile.win rename to R/pkg/src-native/Makefile.win diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src-native/string_hash_code.c similarity index 100% rename from R/pkg/src/string_hash_code.c rename to R/pkg/src-native/string_hash_code.c diff --git a/README.md b/README.md index 9c09d40e2bda..380422ca00db 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a -rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLlib for machine learning, GraphX for graph processing, +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. @@ -22,7 +22,7 @@ This README file only contains basic setup instructions. Spark is built using [Apache Maven](http://maven.apache.org/). To build Spark and its example programs, run: - mvn -DskipTests clean package + build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at @@ -43,7 +43,7 @@ Try the following command, which should return 1000: Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark - + And run the following command, which should also return 1000: >>> sc.parallelize(range(1000)).count() @@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -75,7 +75,7 @@ can be run using: ./dev/run-tests -Please see the guidance on how to +Please see the guidance on how to [run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). ## A Note About Hadoop Versions diff --git a/assembly/pom.xml b/assembly/pom.xml index 626c8577e31f..e9c6d26ccddc 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 1f3dec91314f..ed5c37e595a9 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index ccb262a4ee02..fb10d734ac74 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.bagel -import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { +class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ diff --git a/bin/pyspark b/bin/pyspark index 8acad6113797..f9dbddfa5356 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -17,24 +17,10 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/pyspark [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi +export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. @@ -90,11 +76,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 09b4149c2a43..45e9e3def512 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -21,6 +21,7 @@ rem Figure out where the Spark framework is installed set SPARK_HOME=%~dp0.. call %SPARK_HOME%\bin\load-spark-env.cmd +set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options] rem Figure out which Python to use. if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( diff --git a/bin/spark-class b/bin/spark-class index c49d97ce5cf2..2b59e5df5736 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,18 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -set -e # Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" . "$SPARK_HOME"/bin/load-spark-env.sh -if [ -z "$1" ]; then - echo "Usage: spark-class []" 1>&2 - exit 1 -fi - # Find the java binary if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" @@ -64,24 +58,6 @@ fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" -# Verify that versions of java used to build the jars and run Spark are compatible -if [ -n "$JAVA_HOME" ]; then - JAR_CMD="$JAVA_HOME/bin/jar" -else - JAR_CMD="jar" -fi - -if [ $(command -v "$JAR_CMD") ] ; then - jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1) - if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then - echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2 - echo "This is likely because Spark was compiled with Java 7 and run " 1>&2 - echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2 - echo "or build Spark with Java 6." 1>&2 - exit 1 - fi -fi - LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" # Add the launcher build dir to the classpath if requested. @@ -98,9 +74,4 @@ CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") - -if [ "${CMD[0]}" = "usage" ]; then - "${CMD[@]}" -else - exec "${CMD[@]}" -fi +exec "${CMD[@]}" diff --git a/bin/spark-shell b/bin/spark-shell index b3761b5e1375..a6dc863d83fc 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -29,20 +29,7 @@ esac set -o posix export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -usage() { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-shell [options]" - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi +export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" # SPARK-4161: scala does not assume use of the java classpath, # so we need to add the "-Dscala.usejavacp=true" flag manually. We diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 00fd30fa38d3..251309d67f86 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -18,12 +18,7 @@ rem limitations under the License. rem set SPARK_HOME=%~dp0.. - -echo "%*" | findstr " \<--help\> \<-h\>" >nul -if %ERRORLEVEL% equ 0 ( - call :usage - exit /b 0 -) +set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options] rem SPARK-4161: scala does not assume use of the java classpath, rem so we need to add the "-Dscala.usejavacp=true" flag manually. We @@ -37,16 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -echo "Usage: .\bin\spark-shell.cmd [options]" >&2 -call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2 -goto :eof +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* diff --git a/bin/spark-sql b/bin/spark-sql index ca1729f4cfcb..4ea7bc6e39c0 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -17,41 +17,6 @@ # limitations under the License. # -# -# Shell script for starting the Spark SQL CLI - -# Enter posix mode for bash -set -o posix - -# NOTE: This exact class name is matched downstream by SparkSubmit. -# Any changes need to be reflected there. -export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" - -# Figure out where Spark is installed export FWDIR="$(cd "`dirname "$0"`"/..; pwd)" - -function usage { - if [ -n "$1" ]; then - echo "$1" - fi - echo "Usage: ./bin/spark-sql [options] [cli option]" - pattern="usage" - pattern+="\|Spark assembly has been built with Hive" - pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" - pattern+="\|Spark Command: " - pattern+="\|--help" - pattern+="\|=======" - - "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - echo - echo "CLI options:" - "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2 - exit "$2" -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage "" 0 -fi - -exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@" +export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" +exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 0e0afe71a0f0..255378b0f077 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,16 +22,4 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 -# Only define a usage function if an upstream script hasn't done so. -if ! type -t usage >/dev/null 2>&1; then - usage() { - if [ -n "$1" ]; then - echo "$1" - fi - "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help - exit "$2" - } - export -f usage -fi - exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index d3fc4a5cc3f6..651376e52692 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -24,15 +24,4 @@ rem disable randomized hash for string in Python 3.3+ set PYTHONHASHSEED=0 set CLASS=org.apache.spark.deploy.SparkSubmit -call %~dp0spark-class2.cmd %CLASS% %* -set SPARK_ERROR_LEVEL=%ERRORLEVEL% -if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" ( - call :usage - exit /b 1 -) -exit /b %SPARK_ERROR_LEVEL% - -:usage -echo %SPARK_LAUNCHER_USAGE_ERROR% -call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help -goto :eof +%~dp0spark-class2.cmd %CLASS% %* diff --git a/bin/sparkR b/bin/sparkR index 8c918e2b09ae..464c29f36942 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -17,23 +17,7 @@ # limitations under the License. # -# Figure out where Spark is installed export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" - source "$SPARK_HOME"/bin/load-spark-env.sh - -function usage() { - if [ -n "$1" ]; then - echo $1 - fi - echo "Usage: ./bin/sparkR [options]" 1>&2 - "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit $2 -} -export -f usage - -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - usage -fi - +export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]" exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@" diff --git a/build/mvn b/build/mvn index 3561110a4c01..e8364181e823 100755 --- a/build/mvn +++ b/build/mvn @@ -69,11 +69,14 @@ install_app() { # Install maven under the build/ folder install_mvn() { + local MVN_VERSION="3.3.3" + install_app \ - "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \ - "apache-maven-3.2.5-bin.tar.gz" \ - "apache-maven-3.2.5/bin/mvn" - MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn" + "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "apache-maven-${MVN_VERSION}-bin.tar.gz" \ + "apache-maven-${MVN_VERSION}/bin/mvn" + + MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn" } # Install zinc under the build/ folder @@ -105,28 +108,16 @@ install_scala() { SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" } -# Determines if a given application is already installed. If not, will attempt -# to install -## Arg1 - application name -## Arg2 - Alternate path to local install under build/ dir -check_and_install_app() { - # create the local environment variable in uppercase - local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN" - # some black magic to set the generated app variable (i.e. MVN_BIN) into the - # environment - eval "${app_bin}=`which $1 2>/dev/null`" - - if [ -z "`which $1 2>/dev/null`" ]; then - install_$1 - fi -} - # Setup healthy defaults for the Zinc port if none were provided from # the environment ZINC_PORT=${ZINC_PORT:-"3030"} -# Check and install all applications necessary to build Spark -check_and_install_app "mvn" +# Install Maven if necessary +MVN_BIN="$(command -v mvn)" + +if [ ! "$MVN_BIN" ]; then + install_mvn +fi # Install the proper version of Scala and Zinc for the build install_zinc diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 2e0cb5db170a..7f17bc7eea4f 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -4,7 +4,7 @@ # divided into instances which correspond to internal components. # Each instance can be configured to report its metrics to one or more sinks. # Accepted values for [instance] are "master", "worker", "executor", "driver", -# and "applications". A wild card "*" can be used as an instance name, in +# and "applications". A wildcard "*" can be used as an instance name, in # which case all instances will inherit the supplied property. # # Within an instance, a "source" specifies a particular set of grouped metrics. @@ -32,7 +32,7 @@ # name (see examples below). # 2. Some sinks involve a polling period. The minimum allowed polling period # is 1 second. -# 3. Wild card properties can be overridden by more specific properties. +# 3. Wildcard properties can be overridden by more specific properties. # For example, master.sink.console.period takes precedence over # *.sink.console.period. # 4. A metrics specific configuration @@ -47,6 +47,13 @@ # instance master and applications. MetricsServlet may not be configured by self. # +## List of available common sources and their properties. + +# org.apache.spark.metrics.source.JvmSource +# Note: Currently, JvmSource is the only available common source +# to add additionaly to an instance, to enable this, +# set the "class" option to its fully qulified class name (see examples below) + ## List of available sinks and their properties. # org.apache.spark.metrics.sink.ConsoleSink @@ -126,9 +133,9 @@ #*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink # Polling period for Slf4JSink -#*.sink.sl4j.period=1 +#*.sink.slf4j.period=1 -#*.sink.sl4j.unit=minutes +#*.sink.slf4j.unit=minutes # Enable jvm source for instance master, worker, driver and executor diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 43c4288912b1..192d3ae09113 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -22,7 +22,7 @@ # - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. diff --git a/core/pom.xml b/core/pom.xml index bfa49d0d6dc2..aee0d9262060 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -69,16 +69,6 @@ org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark @@ -338,6 +328,12 @@ org.seleniumhq.selenium selenium-java + + + com.google.guava + guava + + test @@ -348,7 +344,7 @@ org.mockito - mockito-all + mockito-core test @@ -377,9 +373,15 @@ test - org.spark-project + net.razorvine pyrolite 4.4 + + + net.razorvine + serpent + + net.sf.py4j diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java new file mode 100644 index 000000000000..d3d6280284be --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path + * writes incoming records to separate files, one file per reduce partition, then concatenates these + * per-partition files to form a single output file, regions of which are served to reducers. + * Records are not buffered in memory. This is essentially identical to + * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. + *

+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it + * simultaneously opens separate serializers and file streams for all partitions. As a result, + * {@link SortShuffleManager} only selects this write path when + *

    + *
  • no Ordering is specified,
  • + *
  • no Aggregator is specific, and
  • + *
  • the number of partitions is less than + * spark.shuffle.sort.bypassMergeThreshold.
  • + *
+ * + * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was + * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details. + *

+ * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { + + private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final Serializer serializer; + + /** Array of file writers, one for each partition */ + private BlockObjectWriter[] partitionWriters; + + public BypassMergeSortShuffleWriter( + SparkConf conf, + BlockManager blockManager, + Partitioner partitioner, + ShuffleWriteMetrics writeMetrics, + Serializer serializer) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.numPartitions = partitioner.numPartitions(); + this.blockManager = blockManager; + this.partitioner = partitioner; + this.writeMetrics = writeMetrics; + this.serializer = serializer; + } + + @Override + public void insertAll(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new BlockObjectWriter[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (BlockObjectWriter writer : partitionWriters) { + writer.commitAndClose(); + } + } + + @Override + public long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException { + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public void stop() throws IOException { + if (partitionWriters != null) { + try { + final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); + for (BlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + writer.revertPartialWritesAndClose(); + if (!diskBlockManager.getFile(writer.blockId()).delete()) { + logger.error("Error while deleting file for block {}", writer.blockId()); + } + } + } finally { + partitionWriters = null; + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java new file mode 100644 index 000000000000..656ea0401a14 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Product2; +import scala.collection.Iterator; + +import org.apache.spark.annotation.Private; +import org.apache.spark.TaskContext; +import org.apache.spark.storage.BlockId; + +/** + * Interface for objects that {@link SortShuffleWriter} uses to write its output files. + */ +@Private +public interface SortShuffleFileWriter { + + void insertAll(Iterator> records) throws IOException; + + /** + * Write all the data added into this shuffle sorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException; + + void stop() throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index ad7eb04afcd8..764578b18142 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -139,6 +139,9 @@ public void write(Iterator> records) throws IOException { @Override public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we ecountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. boolean success = false; try { while (records.hasNext()) { @@ -147,8 +150,19 @@ public void write(scala.collection.Iterator> records) throws IOEx closeAndWriteOutput(); success = true; } finally { - if (!success) { - sorter.cleanupAfterError(); + if (sorter != null) { + try { + sorter.cleanupAfterError(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. + if (success) { + throw e; + } else { + logger.error("In addition to a failure during writing, we failed during " + + "cleanup.", e); + } + } } } } diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties new file mode 100644 index 000000000000..b146f8a78412 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -0,0 +1,12 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 013db8df9b36..0b450dc76bc3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -50,4 +50,9 @@ $(function() { $("span.additional-metric-title").click(function() { $(this).parent().find('input[type="checkbox"]').trigger('click'); }); + + // Trigger a double click on the span to show full job description. + $(".description-input").dblclick(function() { + $(this).removeClass("description-input").addClass("description-input-full"); + }); }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js index c55f752620df..2d9262b972a5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -20,7 +20,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -module.exports={graphlib:require("./lib/graphlib"),dagre:require("./lib/dagre"),intersect:require("./lib/intersect"),render:require("./lib/render"),util:require("./lib/util"),version:require("./lib/version")}},{"./lib/dagre":8,"./lib/graphlib":9,"./lib/intersect":10,"./lib/render":23,"./lib/util":25,"./lib/version":26}],2:[function(require,module,exports){var util=require("./util");module.exports={"default":normal,normal:normal,vee:vee,undirected:undirected};function normal(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function vee(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 L 4 5 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function undirected(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 5 L 10 5").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}},{"./util":25}],3:[function(require,module,exports){var _=require("./lodash"),addLabel=require("./label/add-label"),util=require("./util");module.exports=createClusters;function createClusters(selection,g){var clusters=g.nodes().filter(function(v){return util.isSubgraph(g,v)}),svgClusters=selection.selectAll("g.cluster").data(clusters,function(v){return v});var makeClusterIdentifier=function(v){return"cluster_"+v.replace(/^cluster/,"")};svgClusters.enter().append("g").attr("class",makeClusterIdentifier).attr("name",function(v){return g.node(v).label}).classed("cluster",true).style("opacity",0).append("rect");var sortedClusters=util.orderByRank(g,svgClusters.data());for(var i=0;i0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ +module.exports={graphlib:require("./lib/graphlib"),dagre:require("./lib/dagre"),intersect:require("./lib/intersect"),render:require("./lib/render"),util:require("./lib/util"),version:require("./lib/version")}},{"./lib/dagre":8,"./lib/graphlib":9,"./lib/intersect":10,"./lib/render":23,"./lib/util":25,"./lib/version":26}],2:[function(require,module,exports){var util=require("./util");module.exports={"default":normal,normal:normal,vee:vee,undirected:undirected};function normal(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function vee(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 L 4 5 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function undirected(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 5 L 10 5").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}},{"./util":25}],3:[function(require,module,exports){var _=require("./lodash"),addLabel=require("./label/add-label"),util=require("./util");module.exports=createClusters;function createClusters(selection,g){var clusters=g.nodes().filter(function(v){return util.isSubgraph(g,v)}),svgClusters=selection.selectAll("g.cluster").data(clusters,function(v){return v});var makeClusterIdentifier=function(v){return"cluster_"+v.replace(/^cluster/,"")};svgClusters.enter().append("g").attr("class",makeClusterIdentifier).attr("name",function(v){return g.node(v).label}).classed("cluster",true).style("opacity",0).append("rect");var sortedClusters=util.orderByRank(g,svgClusters.data());for(var i=0;i0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}; diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index dbacbf19beee..dde6069000bc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -100,7 +100,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortfwdind')); sortrevind = document.createElement('span'); sortrevind.id = "sorttable_sortrevind"; - sortrevind.innerHTML = stIsIE ? ' 5' : ' ▴'; + sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾'; this.appendChild(sortrevind); return; } @@ -113,7 +113,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortrevind')); sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); return; } @@ -134,7 +134,7 @@ sorttable = { this.className += ' sorttable_sorted'; sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); // build an array to sort. This is a Schwartzian transform thing, diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index eedefb44b96f..3b4ae2ed354b 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -15,32 +15,21 @@ * limitations under the License. */ -#dag-viz-graph svg path { - stroke: #444; - stroke-width: 1.5px; -} - -#dag-viz-graph svg g.cluster rect { - stroke-width: 1px; -} - -#dag-viz-graph svg g.node circle { - fill: #444; +#dag-viz-graph a, #dag-viz-graph a:hover { + text-decoration: none; } -#dag-viz-graph svg g.node rect { - fill: #C3EBFF; - stroke: #3EC0FF; - stroke-width: 1px; +#dag-viz-graph .label { + font-weight: normal; + text-shadow: none; } -#dag-viz-graph svg g.node.cached circle { - fill: #444; +#dag-viz-graph svg path { + stroke: #444; + stroke-width: 1.5px; } -#dag-viz-graph svg g.node.cached rect { - fill: #B3F5C5; - stroke: #56F578; +#dag-viz-graph svg g.cluster rect { stroke-width: 1px; } @@ -61,12 +50,23 @@ stroke-width: 1px; } -#dag-viz-graph svg.job g.cluster[class*="stage"] rect { +#dag-viz-graph svg.job g.cluster.skipped rect { + fill: #D6D6D6; + stroke: #B7B7B7; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g.cluster.stage rect { fill: #FFFFFF; stroke: #FF99AC; stroke-width: 1px; } +#dag-viz-graph svg.job g.cluster.stage.skipped rect { + stroke: #ADADAD; + stroke-width: 1px; +} + #dag-viz-graph svg.job g#cross-stage-edges path { fill: none; } @@ -75,6 +75,20 @@ fill: #333; } +#dag-viz-graph svg.job g.cluster.skipped text { + fill: #666; +} + +#dag-viz-graph svg.job g.node circle { + fill: #444; +} + +#dag-viz-graph svg.job g.node.cached circle { + fill: #A3F545; + stroke: #52C366; + stroke-width: 2px; +} + /* Stage page specific styles */ #dag-viz-graph svg.stage g.cluster rect { @@ -83,7 +97,7 @@ stroke-width: 1px; } -#dag-viz-graph svg.stage g.cluster[class*="stage"] rect { +#dag-viz-graph svg.stage g.cluster.stage rect { fill: #FFFFFF; stroke: #FFA6B6; stroke-width: 1px; @@ -97,11 +111,14 @@ fill: #333; } -#dag-viz-graph a, #dag-viz-graph a:hover { - text-decoration: none; +#dag-viz-graph svg.stage g.node rect { + fill: #C3EBFF; + stroke: #3EC0FF; + stroke-width: 1px; } -#dag-viz-graph .label { - font-weight: normal; - text-shadow: none; +#dag-viz-graph svg.stage g.node.cached rect { + fill: #B3F5C5; + stroke: #52C366; + stroke-width: 2px; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index ee48fd29a643..9fa53baaf421 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -57,9 +57,7 @@ var VizConstants = { stageSep: 40, graphPrefix: "graph_", nodePrefix: "node_", - stagePrefix: "stage_", - clusterPrefix: "cluster_", - stageClusterPrefix: "cluster_stage_" + clusterPrefix: "cluster_" }; var JobPageVizConstants = { @@ -133,9 +131,7 @@ function renderDagViz(forJob) { } // Render - var svg = graphContainer() - .append("svg") - .attr("class", jobOrStage); + var svg = graphContainer().append("svg").attr("class", jobOrStage); if (forJob) { renderDagVizForJob(svg); } else { @@ -144,7 +140,8 @@ function renderDagViz(forJob) { // Find cached RDDs and mark them as such metadataContainer().selectAll(".cached-rdd").each(function(v) { - var nodeId = VizConstants.nodePrefix + d3.select(this).text(); + var rddId = d3.select(this).text().trim(); + var nodeId = VizConstants.nodePrefix + rddId; svg.selectAll("g." + nodeId).classed("cached", true); }); @@ -154,7 +151,7 @@ function renderDagViz(forJob) { /* Render the RDD DAG visualization on the stage page. */ function renderDagVizForStage(svgContainer) { var metadata = metadataContainer().select(".stage-metadata"); - var dot = metadata.select(".dot-file").text(); + var dot = metadata.select(".dot-file").text().trim(); var containerId = VizConstants.graphPrefix + metadata.attr("stage-id"); var container = svgContainer.append("g").attr("id", containerId); renderDot(dot, container, false); @@ -185,23 +182,32 @@ function renderDagVizForJob(svgContainer) { var dot = metadata.select(".dot-file").text(); var stageId = metadata.attr("stage-id"); var containerId = VizConstants.graphPrefix + stageId; - // Link each graph to the corresponding stage page (TODO: handle stage attempts) - var stageLink = $("#stage-" + stageId.replace(VizConstants.stagePrefix, "") + "-0") - .find("a") - .attr("href") + "&expandDagViz=true"; - var container = svgContainer - .append("a") - .attr("xlink:href", stageLink) - .append("g") - .attr("id", containerId); + var isSkipped = metadata.attr("skipped") == "true"; + var container; + if (isSkipped) { + container = svgContainer + .append("g") + .attr("id", containerId) + .attr("skipped", "true"); + } else { + // Link each graph to the corresponding stage page (TODO: handle stage attempts) + // Use the link from the stage table so it also works for the history server + var attemptId = 0 + var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) + .select("a.name-link") + .attr("href") + "&expandDagViz=true"; + container = svgContainer + .append("a") + .attr("xlink:href", stageLink) + .append("g") + .attr("id", containerId); + } // Now we need to shift the container for this stage so it doesn't overlap with // existing ones, taking into account the position and width of the last stage's // container. We do not need to do this for the first stage of this job. if (i > 0) { - var existingStages = svgContainer - .selectAll("g.cluster") - .filter("[class*=\"" + VizConstants.stageClusterPrefix + "\"]"); + var existingStages = svgContainer.selectAll("g.cluster.stage") if (!existingStages.empty()) { var lastStage = d3.select(existingStages[0].pop()); var lastStageWidth = toFloat(lastStage.select("rect").attr("width")); @@ -214,6 +220,12 @@ function renderDagVizForJob(svgContainer) { // Actually render the stage renderDot(dot, container, true); + // Mark elements as skipped if appropriate. Unfortunately we need to mark all + // elements instead of the parent container because of CSS override rules. + if (isSkipped) { + container.selectAll("g").classed("skipped", true); + } + // Round corners on rectangles container .selectAll("rect") @@ -224,7 +236,7 @@ function renderDagVizForJob(svgContainer) { // them separately later. Note that we cannot draw them now because we need to // put these edges in a separate container that is on top of all stage graphs. metadata.selectAll(".incoming-edge").each(function(v) { - var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4] + var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4] crossStageEdges.push(edge); }); }); @@ -243,6 +255,9 @@ function renderDot(dot, container, forJob) { var renderer = new dagreD3.render(); preprocessGraphLayout(g, forJob); renderer(container, g); + + // Find the stage cluster and mark it for styling and post-processing + container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true); } /* -------------------- * diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 604c29994145..ca74ef9d7e94 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -46,7 +46,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var jobPagePath = $(getSelectorForJobEntry(this)).find("a").attr("href") + var jobPagePath = $(getSelectorForJobEntry(this)).find("a.name-link").attr("href") window.location.href = jobPagePath }); @@ -105,7 +105,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var stagePagePath = $(getSelectorForStageEntry(this)).find("a").attr("href") + var stagePagePath = $(getSelectorForStageEntry(this)).find("a.name-link").attr("href") window.location.href = stagePagePath }); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index e7c1d475d4e5..b1cef4704224 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -135,6 +135,14 @@ pre { display: block; } +.description-input-full { + overflow: hidden; + text-overflow: ellipsis; + width: 100%; + white-space: normal; + display: block; +} + .stacktrace-details { max-height: 300px; overflow-y: auto; diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 330df1d59a9b..5a8d17bd9993 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -228,7 +228,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @tparam T result type */ class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T,T](initialValue, param, name) { + extends Accumulable[T, T](initialValue, param, name) { def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index af9765d313e9..ceeb58075d34 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,8 +34,8 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. + // When spilling is enabled sorting will happen externally, but not necessarily with an + // ExternalSorter. private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") @@ -45,7 +45,7 @@ case class Aggregator[K, V, C] ( def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kv: Product2[K, V] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) @@ -76,7 +76,7 @@ case class Aggregator[K, V, C] ( : Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kc: Product2[K, C] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 66bda6808850..49329423dca7 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -91,7 +91,7 @@ private[spark] class ExecutorAllocationManager( // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.schedulerBacklogTimeout", "5s") + "spark.dynamicAllocation.schedulerBacklogTimeout", "1s") // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -99,7 +99,10 @@ private[spark] class ExecutorAllocationManager( // How long an executor must be idle for before it is removed (seconds) private val executorIdleTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.executorIdleTimeout", "600s") + "spark.dynamicAllocation.executorIdleTimeout", "60s") + + private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds( + "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -150,6 +153,13 @@ private[spark] class ExecutorAllocationManager( // Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem. val executorAllocationManagerSource = new ExecutorAllocationManagerSource + // Whether we are still waiting for the initial set of executors to be allocated. + // While this is true, we will not cancel outstanding executor requests. This is + // set to false when: + // (1) a stage is submitted, or + // (2) an executor idle timeout has elapsed. + @volatile private var initializing: Boolean = true + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -240,6 +250,7 @@ private[spark] class ExecutorAllocationManager( removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime if (expired) { + initializing = false removeExecutor(executorId) } !expired @@ -261,13 +272,23 @@ private[spark] class ExecutorAllocationManager( private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { val maxNeeded = maxNumExecutorsNeeded - if (maxNeeded < numExecutorsTarget) { + if (initializing) { + // Do not change our target while we are still initializing, + // Otherwise the first job may have to ramp up unnecessarily + 0 + } else if (maxNeeded < numExecutorsTarget) { // The target number exceeds the number we actually need, so stop adding new // executors and inform the cluster manager to cancel the extra pending requests val oldNumExecutorsTarget = numExecutorsTarget numExecutorsTarget = math.max(maxNeeded, minNumExecutors) - client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 + + // If the new target has not changed, avoid sending a message to the cluster manager + if (numExecutorsTarget < oldNumExecutorsTarget) { + client.requestTotalExecutors(numExecutorsTarget) + logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") + } numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) @@ -292,9 +313,8 @@ private[spark] class ExecutorAllocationManager( private def addExecutors(maxNumExecutorsNeeded: Int): Int = { // Do not request more executors if it would put our target over the upper bound if (numExecutorsTarget >= maxNumExecutors) { - val numExecutorsPending = numExecutorsTarget - executorIds.size - logDebug(s"Not adding executors because there are already ${executorIds.size} registered " + - s"and ${numExecutorsPending} pending executor(s) (limit $maxNumExecutors)") + logDebug(s"Not adding executors because our current target total " + + s"is already $numExecutorsTarget (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } @@ -310,10 +330,19 @@ private[spark] class ExecutorAllocationManager( // Ensure that our target fits within configured bounds: numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) + val delta = numExecutorsTarget - oldNumExecutorsTarget + + // If our target has not changed, do not send a message + // to the cluster manager and reset our exponential growth + if (delta == 0) { + numExecutorsToAdd = 1 + return 0 + } + val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) if (addRequestAcknowledged) { - val delta = numExecutorsTarget - oldNumExecutorsTarget - logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + + val executorsString = "executor" + { if (delta > 1) "s" else "" } + logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + s" (new desired total will be $numExecutorsTarget)") numExecutorsToAdd = if (delta == numExecutorsToAdd) { numExecutorsToAdd * 2 @@ -420,7 +449,7 @@ private[spark] class ExecutorAllocationManager( * This resets all variables used for adding executors. */ private def onSchedulerQueueEmpty(): Unit = synchronized { - logDebug(s"Clearing timer to add executors because there are no more pending tasks") + logDebug("Clearing timer to add executors because there are no more pending tasks") addTime = NOT_SET numExecutorsToAdd = 1 } @@ -433,9 +462,23 @@ private[spark] class ExecutorAllocationManager( private def onExecutorIdle(executorId: String): Unit = synchronized { if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + // Note that it is not necessary to query the executors since all the cached + // blocks we are concerned with are reported to the driver. Note that this + // does not include broadcast blocks. + val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val now = clock.getTimeMillis() + val timeout = { + if (hasCachedBlocks) { + // Use a different timeout if the executor has cached blocks. + now + cachedExecutorIdleTimeoutS * 1000 + } else { + now + executorIdleTimeoutS * 1000 + } + } + val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow + removeTimes(executorId) = realTimeout logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000 + s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)") } } else { logWarning(s"Attempted to mark unknown executor $executorId idle") @@ -467,6 +510,7 @@ private[spark] class ExecutorAllocationManager( private var numRunningTasks: Int = _ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + initializing = false val stageId = stageSubmitted.stageInfo.stageId val numTasks = stageSubmitted.stageInfo.numTasks allocationManager.synchronized { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 91f9ef8ce718..48792a958130 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished - + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index f2b024ff6cb6..221b1dab4327 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -24,12 +24,12 @@ import scala.collection.mutable import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. It will also + * components to convey liveness or execution information for in-progress tasks. It will also * expire the hosts that have not heartbeated for more than spark.network.timeout. */ private[spark] case class Heartbeat( @@ -43,15 +43,25 @@ private[spark] case class Heartbeat( */ private[spark] case object TaskSchedulerIsSet -private[spark] case object ExpireDeadHosts - +private[spark] case object ExpireDeadHosts + +private case class ExecutorRegistered(executorId: String) + +private case class ExecutorRemoved(executorId: String) + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext) - extends ThreadSafeRpcEndpoint with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) + extends ThreadSafeRpcEndpoint with SparkListener with Logging { + + def this(sc: SparkContext) { + this(sc, new SystemClock) + } + + sc.addSparkListener(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv @@ -62,18 +72,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = + private val slaveTimeoutMs = sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") - private val executorTimeoutMs = + private val executorTimeoutMs = sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 - + // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" - private val timeoutIntervalMs = + private val timeoutIntervalMs = sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") - private val checkTimeoutIntervalMs = + private val checkTimeoutIntervalMs = sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 - + private var timeoutCheckingTask: ScheduledFuture[_] = null // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not @@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def onStart(): Unit = { timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ExpireDeadHosts)) + Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) } }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - override def receive: PartialFunction[Any, Unit] = { - case ExpireDeadHosts => - expireDeadHosts() + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + + // Messages sent and received locally + case ExecutorRegistered(executorId) => + executorLastSeen(executorId) = clock.getTimeMillis() + context.reply(true) + case ExecutorRemoved(executorId) => + executorLastSeen.remove(executorId) + context.reply(true) case TaskSchedulerIsSet => scheduler = sc.taskScheduler - } + context.reply(true) + case ExpireDeadHosts => + expireDeadHosts() + context.reply(true) - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // Messages received from executors case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - executorLastSeen(executorId) = System.currentTimeMillis() - eventLoopThread.submit(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - context.reply(response) - } - }) + if (executorLastSeen.contains(executorId)) { + executorLastSeen(executorId) = clock.getTimeMillis() + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) + } else { + // This may happen if we get an executor's in-flight heartbeat immediately + // after we just removed it. It's not really an error condition so we should + // not log warning here. Otherwise there may be a lot of noise especially if + // we explicitly remove executors (SPARK-4134). + logDebug(s"Received heartbeat from unknown executor $executorId") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } + /** + * If the heartbeat receiver is not stopped, notify it of executor registrations. + */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + } + + /** + * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't + * log superfluous errors. + * + * Note that we must do this after the executor is actually removed to guard against the + * following race condition: if we remove an executor's metadata from our data structure + * prematurely, we may get an in-flight heartbeat from the executor before the executor is + * actually removed, in which case we will still mark the executor as a dead host later + * and expire it with loud error messages. + */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + } + private def expireDeadHosts(): Unit = { logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") - val now = System.currentTimeMillis() + val now = clock.getTimeMillis() for ((executorId, lastSeenMs) <- executorLastSeen) { if (now - lastSeenMs > executorTimeoutMs) { logWarning(s"Removing executor $executorId with no recent heartbeats: " + @@ -140,7 +189,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } } - + override def onStop(): Unit = { if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7e706bcc42f0..7cf7bc0dc681 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -50,8 +50,8 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() - - // If we only stop sc, but the driver process still run as a services then we need to delete + + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs try { Utils.deleteRecursively(baseDir) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 419d093d5564..7fcb7830e7b0 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,13 +121,25 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (Utils.isInInterpreter) { + val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" + Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") + System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") + case None => + System.err.println(s"Spark was unable to load $replDefaultLogProps") + } + } else { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") + } } } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 018422827e1c..862ffe868f58 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -284,6 +284,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return a list of locations that each have fraction of map output greater than the specified + * threshold. + * + * @param shuffleId id of the shuffle + * @param reducerId id of the reduce task + * @param numReducers total number of reducers in the shuffle + * @param fractionThreshold fraction of total map output size that a location must have + * for it to be considered large. + * + * This method is not thread-safe. + */ + def getLocationsWithLargestOutputs( + shuffleId: Int, + reducerId: Int, + numReducers: Int, + fractionThreshold: Double) + : Option[Array[BlockManagerId]] = { + + if (mapStatuses.contains(shuffleId)) { + val statuses = mapStatuses(shuffleId) + if (statuses.nonEmpty) { + // HashMap to add up sizes of all blocks at the same location + val locs = new HashMap[BlockManagerId, Long] + var totalOutputSize = 0L + var mapIdx = 0 + while (mapIdx < statuses.length) { + val status = statuses(mapIdx) + val blockSize = status.getSizeForBlock(reducerId) + if (blockSize > 0) { + locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize + totalOutputSize += blockSize + } + mapIdx = mapIdx + 1 + } + val topLocs = locs.filter { case (loc, size) => + size.toDouble / totalOutputSize >= fractionThreshold + } + // Return if we have any locations which satisfy the required threshold + if (topLocs.nonEmpty) { + return Some(topLocs.map(_._1).toArray) + } + } + } + None + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index b8d244408bc5..82889bcd3098 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -103,7 +103,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { */ class RangePartitioner[K : Ordering : ClassTag, V]( @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], + @transient rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -185,7 +185,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => + case r: RangePartitioner[_, _] => r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending case _ => false @@ -249,7 +249,7 @@ private[spark] object RangePartitioner { * @param sampleSizePerPartition max sample size per partition * @return (total number of items, an array of (partitionId, number of items, sample)) */ - def sketch[K:ClassTag]( + def sketch[K : ClassTag]( rdd: RDD[K], sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id @@ -272,7 +272,7 @@ private[spark] object RangePartitioner { * @param partitions number of partitions * @return selected bounds */ - def determineBounds[K:Ordering:ClassTag]( + def determineBounds[K : Ordering : ClassTag]( candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = { val ordering = implicitly[Ordering[K]] diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af..32df42d57dbd 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,7 +17,9 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileInputStream} +import java.security.{KeyStore, NoSuchAlgorithmException} +import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +50,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +66,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -94,7 +97,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +105,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. */ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 8aed1e20e068..673ef49e7c1c 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -192,7 +192,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" - private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 private var aclsOn = sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false)) @@ -365,10 +365,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) cookie } else { // user must have set spark.authenticate.secret config - sparkConf.getOption("spark.authenticate.secret") match { + // For Master/Worker, auth secret is in conf; for Executors, it is in env variable + sys.env.get(SecurityManager.ENV_AUTH_SECRET) + .orElse(sparkConf.getOption(SecurityManager.SPARK_AUTH_SECRET_CONF)) match { case Some(value) => value case None => throw new Exception("Error: a secret key must be specified via the " + - "spark.authenticate.secret config") + SecurityManager.SPARK_AUTH_SECRET_CONF + " config") } } sCookie @@ -449,3 +451,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) override def getSaslUser(appId: String): String = getSaslUser() override def getSecretKey(appId: String): String = getSecretKey() } + +private[spark] object SecurityManager { + + val SPARK_AUTH_CONF: String = "spark.authenticate" + val SPARK_AUTH_SECRET_CONF: String = "spark.authenticate.secret" + // This is used to set auth secret to an executor's env variable. It should have the same + // value as SPARK_AUTH_SECERET_CONF set in SparkConf + val ENV_AUTH_SECRET = "_SPARK_AUTH_SECRET" +} diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index cb2cae185256..beb2e2725472 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -41,7 +41,7 @@ class SerializableWritable[T <: Writable](@transient var t: T) extends Serializa private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() val ow = new ObjectWritable() - ow.setConf(new Configuration()) + ow.setConf(new Configuration(false)) ow.readFields(in) t = ow.get().asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala deleted file mode 100644 index 54fc3a856adf..000000000000 --- a/core/src/main/scala/org/apache/spark/SizeEstimator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.apache.spark.annotation.DeveloperApi - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -@DeveloperApi -object SizeEstimator { - /** - * :: DeveloperApi :: - * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate - * includes space taken up by objects referenced by the given object, their references, and so on - * and so forth. - * - * This is useful for determining the amount of heap space a broadcast variable will occupy on - * each executor or the amount of space each object will take when caching objects in - * deserialized form. This is not the same as the serialized size of the object, which will - * typically be much smaller. - */ - @DeveloperApi - def estimate(obj: AnyRef): Long = org.apache.spark.util.SizeEstimator.estimate(obj) -} diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a8fc90ad2050..6cf36fbbd625 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsBytes(key: String, defaultValue: String): Long = { Utils.byteStringAsBytes(get(key, defaultValue)) } - + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. @@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsKb(key: String, defaultValue: String): Long = { Utils.byteStringAsKb(get(key, defaultValue)) } - + /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. @@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsMb(key: String, defaultValue: String): Long = { Utils.byteStringAsMb(get(key, defaultValue)) } - + /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. @@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsGb(key: String, defaultValue: String): Long = { Utils.byteStringAsGb(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -480,8 +480,8 @@ private[spark] object SparkConf extends Logging { "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'.") ) - - Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) } /** @@ -508,8 +508,8 @@ private[spark] object SparkConf extends Logging { "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${s.toDouble * 1000}k")), + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( @@ -557,7 +557,7 @@ private[spark] object SparkConf extends Logging { def isExecutorStartupConf(name: String): Boolean = { isAkkaConf(name) || name.startsWith("spark.akka") || - name.startsWith("spark.auth") || + (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || isSparkPortConf(name) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index af276e7b8d40..d2547eeff2b4 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId @@ -389,7 +397,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) .toSeq.flatten @@ -438,7 +446,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _ui = if (conf.getBoolean("spark.ui.enabled", true)) { Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, - _env.securityManager,appName, startTime = startTime)) + _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI None @@ -490,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) - _heartbeatReceiver.send(TaskSchedulerIsSet) + _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor @@ -545,7 +553,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() - _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) @@ -678,7 +685,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * Note: Return statements are NOT allowed in the given body. */ - private def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) + private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) // Methods for creating RDDs @@ -697,6 +704,78 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } + /** + * Creates a new RDD[Long] containing elements from `start` to `end`(exclusive), increased by + * `step` every element. + * + * @note if we need to cache this RDD, we should make sure each partition does not exceed limit. + * + * @param start the start value. + * @param end the end value. + * @param step the incremental step + * @param numSlices the partition number of the new RDD. + * @return + */ + def range( + start: Long, + end: Long, + step: Long = 1, + numSlices: Int = defaultParallelism): RDD[Long] = withScope { + assertNotStopped() + // when step is 0, range will run infinitely + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + + new Iterator[Long] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + ret + } + } + }) + } + /** Distribute a local Scala collection to form an RDD. * * This method is identical to `parallelize`. @@ -752,7 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -799,9 +879,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( @@ -845,7 +926,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], - conf=conf) + conf = conf) val data = br.map { case (k, v) => val bytes = v.getBytes assert(bytes.length == recordLength, "Byte array does not have correct length") @@ -902,7 +983,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. - val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) + val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) new HadoopRDD( this, @@ -1087,8 +1168,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]): RDD[(K, V)] = { withScope { assertNotStopped() - val kc = kcf() - val vc = vcf() + val kc = clean(kcf)() + val vc = clean(vcf)() val format = classOf[SequenceFileInputFormat[Writable, Writable]] val writables = hadoopFile(path, format, kc.writableClass(km).asInstanceOf[Class[Writable]], @@ -1195,7 +1276,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { - val param = new GrowableAccumulableParam[R,T] + val param = new GrowableAccumulableParam[R, T] val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1244,7 +1325,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val uri = new URI(path) val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString - case _ => path + case _ => path } val hadoopPath = new Path(schemeCorrectedPath) @@ -1812,7 +1893,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability - * @throws SparkException if checkSerializable is set but f is not + * @throws SparkException if checkSerializable is set but f is not * serializable */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { @@ -1825,6 +1906,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * be a HDFS path if running on a cluster. */ def setCheckpointDir(directory: String) { + + // If we are running on a cluster, log a warning if the directory is local. + // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from + // its own local file system, which is incorrect because the checkpoint files + // are actually on the executor machines. + if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { + logWarning("Checkpoint directory must be non-local " + + "if Spark is running on a cluster: " + directory) + } + checkpointDir = Option(directory).map { dir => val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) @@ -1919,7 +2010,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Note: this code assumes that the task scheduler has been initialized and has contacted // the cluster manager to get an application ID (in case the cluster manager provides one). listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId), - startTime, sparkUser, applicationAttemptId)) + startTime, sparkUser, applicationAttemptId, schedulerBackend.getDriverLogUrls)) } /** Post the application end event */ diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 327114542880..d18fc599e989 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,7 +20,8 @@ package org.apache.spark import java.io.File import java.net.Socket -import scala.collection.JavaConversions._ +import akka.actor.ActorSystem + import scala.collection.mutable import scala.util.Properties @@ -75,7 +76,8 @@ class SparkEnv ( val conf: SparkConf) extends Logging { // TODO Remove actorSystem - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") + val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -87,39 +89,42 @@ class SparkEnv ( private var driverTmpDirToDelete: Option[String] = None private[spark] def stop() { - isStopped = true - pythonWorkers.foreach { case(key, worker) => worker.stop() } - Option(httpFileServer).foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() - - // Note that blockTransferService is stopped by BlockManager since it is started by it. - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. - // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the - // current working dir in executor which we do not need to delete. - driverTmpDirToDelete match { - case Some(path) => { - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + Option(httpFileServer).foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. + driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } @@ -168,7 +173,7 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. */ - @deprecated("Use SparkEnv.get instead", "1.2") + @deprecated("Use SparkEnv.get instead", "1.2.0") def getThreadLocal: SparkEnv = { env } @@ -298,7 +303,7 @@ object SparkEnv extends Logging { } } - val mapOutputTracker = if (isDriver) { + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { new MapOutputTrackerWorker(conf) @@ -348,7 +353,7 @@ object SparkEnv extends Logging { val fileServerPort = conf.getInt("spark.fileserver.port", 0) val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) + conf.set("spark.fileserver.uri", server.serverUri) server } else { null diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 2ec42d3aea16..f5dd36cbcfe6 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hadoop OutputFormat. @@ -42,7 +43,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) with Serializable { private val now = new Date() - private val conf = new SerializableWritable(jobConf) + private val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 @@ -50,8 +51,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private var jID: SerializableWritable[JobID] = null private var taID: SerializableWritable[TaskAttemptID] = null - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null + @transient private var writer: RecordWriter[AnyRef, AnyRef] = null + @transient private var format: OutputFormat[AnyRef, AnyRef] = null @transient private var committer: OutputCommitter = null @transient private var jobContext: JobContext = null @transient private var taskContext: TaskAttemptContext = null @@ -114,10 +115,10 @@ class SparkHadoopWriter(@transient jobConf: JobConf) // ********* Private Functions ********* - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { + private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { if (format == null) { format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] + .asInstanceOf[OutputFormat[AnyRef, AnyRef]] } format } @@ -138,7 +139,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + taskContext = newTaskAttemptContext(conf.value, taID.value) } taskContext } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 398ca41e1615..a1ebbecf93b7 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -51,7 +51,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() val files1 = for (name <- classNames) yield { - createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) } val files2 = for ((childName, baseName) <- classNamesWithBase) yield { createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) @@ -105,23 +105,18 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private class JavaSourceFromString(val name: String, val code: String) + private[spark] class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + /** Creates a compiled class with the source file. Class file will be placed in destDir. */ def createCompiledClass( className: String, destDir: File, - toStringValue: String = "", - baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { + sourceFile: JavaSourceFromString, + classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") - val sourceFile = new JavaSourceFromString(className, - "public class " + className + extendsText + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. @@ -144,4 +139,18 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } + + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") + createCompiledClass(className, destDir, sourceFile, classpathUrls) + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 61af867b11b9..a650df605b92 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) */ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index db4e996feb31..ed312770ee13 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] @@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 8bf0627fc420..c95615a5a930 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @deprecated("Use partitions() instead.", "1.1.0") def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - + /** Set of partitions in this RDD. */ def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + /** The partitioner of this RDD. */ + def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context @@ -96,7 +99,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -386,9 +389,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. */ def fold(zeroValue: T)(f: JFunction2[T, T, T]): T = rdd.fold(zeroValue)(f) @@ -485,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } - def takeSample(withReplacement: Boolean, num: Int): JList[T] = + def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index c9181a29d475..b959b683d167 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,8 +19,8 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -61,7 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { + conf: Broadcast[SerializableConfiguration]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 7409dc2d866f..dc9f62f39e6d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -36,7 +36,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal @@ -47,6 +47,7 @@ private[spark] class PythonRDD( pythonIncludes: JList[String], preservePartitoning: Boolean, pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { @@ -210,6 +211,8 @@ private[spark] class PythonRDD( val dataOut = new DataOutputStream(stream) // Partition index dataOut.writeInt(split.index) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Python includes (*.zip and *.egg files) @@ -442,7 +445,7 @@ private[spark] object PythonRDD extends Logging { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -468,7 +471,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -494,7 +497,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -537,7 +540,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -563,7 +566,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -720,7 +723,7 @@ private[spark] object PythonRDD extends Logging { val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] - converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec) } /** @@ -794,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) - /** + /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded actor anyway. - */ + */ @transient var socket: Socket = _ def openSocket(): Socket = synchronized { @@ -840,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * An Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. */ +// scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { /** @@ -881,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } +// scalastyle:on no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index efb6b93cfc35..90dacaeb9342 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -50,8 +50,15 @@ private[spark] object PythonUtils { /** * Convert list of T into seq of T (for calling API with varargs) */ - def toSeq[T](cols: JList[T]): Seq[T] = { - cols.toList.toSeq + def toSeq[T](vs: JList[T]): Seq[T] = { + vs.toList.toSeq + } + + /** + * Convert list of T into array of T (for calling API with array) + */ + def toArray[T](vs: JList[T]): Array[T] = { + vs.toArray().asInstanceOf[Array[T]] } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 0a91977928ce..1a5f2bca26c2 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -29,7 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} /** * Netty-based backend server that is used to communicate between R and Java. @@ -41,14 +41,15 @@ private[spark] class RBackend { private[this] var bossGroup: EventLoopGroup = null def init(): Int = { - bossGroup = new NioEventLoopGroup(2) + val conf = new SparkConf() + bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) val workerGroup = bossGroup val handler = new RBackendHandler(this) - + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(classOf[NioServerSocketChannel]) - + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { def initChannel(ch: SocketChannel): Unit = { ch.pipeline() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 0075d963711f..4b8f7fe9242e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend) val reply = bos.toByteArray ctx.write(reply) } - + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { ctx.flush() } @@ -88,6 +88,21 @@ private[r] class RBackendHandler(server: RBackend) ctx.close() } + // Looks up a class given a class name. This function first checks the + // current class loader and if a class is not found, it looks up the class + // in the context class loader. Address [SPARK-5185] + def getStaticClass(objId: String): Class[_] = { + try { + val clsCurrent = Class.forName(objId) + clsCurrent + } catch { + // Use contextLoader if we can't find the JAR in the system class loader + case e: ClassNotFoundException => + val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader) + clsContext + } + } + def handleMethodCall( isStatic: Boolean, objId: String, @@ -98,7 +113,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - Class.forName(objId) + getStaticClass(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) @@ -124,7 +139,7 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args:_*) + val ret = methods.head.invoke(obj, args : _*) // Write status bit writeInt(dos, 0) @@ -135,7 +150,7 @@ private[r] class RBackendHandler(server: RBackend) matchMethod(numArgs, args, x.getParameterTypes) }.head - val obj = ctor.newInstance(args:_*) + val obj = ctor.newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 06247f7e8b78..524676544d6f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -309,7 +309,7 @@ private class StringRRDD[T: ClassTag]( } private object SpecialLengths { - val TIMING_DATA = -1 + val TIMING_DATA = -1 } private[r] class BufferedStreamThread( @@ -355,7 +355,6 @@ private[r] object RRDD { val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) - .setJars(jars) // Override `master` if we have a user-specified value if (master != "") { @@ -373,7 +372,11 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) } - new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(sparkConf) + jars.foreach { jar => + jsc.addJar(jar) + } + jsc } /** @@ -388,7 +391,7 @@ private[r] object RRDD { } private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { - val rCommand = "Rscript" + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 371dfe454d1a..56adc857d4ce 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} -import java.sql.{Date, Time} +import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConversions._ @@ -107,9 +107,12 @@ private[spark] object SerDe { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Time = { - val t = in.readDouble() - new Time((t * 1000L).toLong) + def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t } def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { @@ -157,9 +160,11 @@ private[spark] object SerDe { val keysLen = readInt(in) val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - val valuesType = readObjectType(in) val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) mapAsJavaMap(keys.zip(values).toMap) } else { new java.util.HashMap[Object, Object]() @@ -225,6 +230,9 @@ private[spark] object SerDe { case "java.sql.Time" => writeType(dos, "time") writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Timestamp]) case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) @@ -287,6 +295,9 @@ private[spark] object SerDe { out.writeDouble(value.getTime.toDouble / 1000.0) } + def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4457c75e8b0f..b69af639f786 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -125,7 +125,7 @@ private[broadcast] object HttpBroadcast extends Logging { securityManager = securityMgr if (isDriver) { createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) + conf.set("spark.httpBroadcast.uri", serverUri) } serverUri = conf.get("spark.httpBroadcast.uri") cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) @@ -187,7 +187,7 @@ private[broadcast] object HttpBroadcast extends Logging { } private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name var uc: URLConnection = null diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 848b62f9de71..71f7e2129116 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -18,17 +18,17 @@ package org.apache.spark.deploy import scala.collection.mutable.HashSet -import scala.concurrent._ +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import scala.util.{Failure, Success} -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} /** * Proxy that relays messages to the driver. @@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} * We currently don't support retry if submission fails. In HA mode, client will submit request to * all masters and see which one could handle it. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - private val masterActors = driverArgs.masters.map { m => - context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) - } - private val lostMasters = new HashSet[Address] - private var activeMasterActor: ActorSelection = null - - val timeout = RpcUtils.askTimeout(conf) - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - +private class ClientEndpoint( + override val rpcEnv: RpcEnv, + driverArgs: ClientArguments, + masterEndpoints: Seq[RpcEndpointRef], + conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging { + + // A scheduled executor used to send messages at the specified time. + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message") + // Used to provide the implicit parameter of `Future` methods. + private val forwardMessageExecutionContext = + ExecutionContext.fromExecutor(forwardMessageThread, + t => t match { + case ie: InterruptedException => // Exit normally + case e: Throwable => + logError(e.getMessage, e) + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + }) + + private val lostMasters = new HashSet[RpcAddress] + private var activeMasterEndpoint: RpcEndpointRef = null + + override def onStart(): Unit = { driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -82,29 +92,37 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.cores, driverArgs.supervise, command) - - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestSubmitDriver(driverDescription) - } + ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestKillDriver(driverId) - } + ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + } + } + + /** + * Send the message to master and forward the reply to self asynchronously. + */ + private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + for (masterEndpoint <- masterEndpoints) { + masterEndpoint.ask[T](message).onComplete { + case Success(v) => self.send(v) + case Failure(e) => + logWarning(s"Error sending messages to master $masterEndpoint", e) + }(forwardMessageExecutionContext) } } /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { + // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread + // is fine. println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusFuture = (activeMasterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + val statusResponse = + activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => println(s"ERROR: Cluster master did not recognize $driverId") @@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { - case SubmitDriverResponse(success, driverId, message) => + case SubmitDriverResponse(master, success, driverId, message) => println(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId.get) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } - case KillDriverResponse(driverId, success, message) => + case KillDriverResponse(master, driverId, success, message) => println(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } + } - case DisassociatedEvent(_, remoteAddress, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") - lostMasters += remoteAddress - // Note that this heuristic does not account for the fact that a Master can recover within - // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This - // is not currently a concern, however, because this client does not retry submissions. - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. + if (lostMasters.size >= masterEndpoints.size) { + println("No master is available, exiting.") + System.exit(-1) } + } + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") - lostMasters += remoteAddress - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master ($remoteAddress).") + println(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterEndpoints.size) { + println("No master is available, exiting.") + System.exit(-1) } + } + } + + override def onError(cause: Throwable): Unit = { + println(s"Error processing messages, exiting.") + cause.printStackTrace() + System.exit(-1) + } + + override def onStop(): Unit = { + forwardMessageThread.shutdownNow() } } @@ -194,15 +224,13 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val rpcEnv = + RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - for (m <- driverArgs.masters) { - Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) - } - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). + map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 316e2d59f01b..42d3296062e6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -117,7 +117,7 @@ private[deploy] class ClientArguments(args: Array[String]) { private[deploy] object ClientArguments { val DEFAULT_CORES = 1 - val DEFAULT_MEMORY = 512 // MB + val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // MB val DEFAULT_SUPERVISE = false def isValidJarUrl(s: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 9db6fd1ac4db..12727de9b4cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable -/** Contains messages sent between Scheduler actor nodes. */ +/** Contains messages sent between Scheduler endpoint nodes. */ private[deploy] object DeployMessages { // Worker to Master @@ -37,6 +38,7 @@ private[deploy] object DeployMessages { id: String, host: String, port: Int, + worker: RpcEndpointRef, cores: Int, memory: Int, webUiPort: Int, @@ -63,11 +65,11 @@ private[deploy] object DeployMessages { case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], driverIds: Seq[String]) - case class Heartbeat(workerId: String) extends DeployMessage + case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage @@ -92,13 +94,13 @@ private[deploy] object DeployMessages { // Worker internal - case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders case object ReregisterWithMaster // used when a worker attempts to reconnect to a master // AppClient to Master - case class RegisterApplication(appDescription: ApplicationDescription) + case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) extends DeployMessage case class UnregisterApplication(appId: String) @@ -107,7 +109,7 @@ private[deploy] object DeployMessages { // Master to AppClient - case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage + case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -123,12 +125,14 @@ private[deploy] object DeployMessages { case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage - case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + case class SubmitDriverResponse( + master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String) extends DeployMessage case class RequestKillDriver(driverId: String) extends DeployMessage - case class KillDriverResponse(driverId: String, success: Boolean, message: String) + case class KillDriverResponse( + master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage case class RequestDriverStatus(driverId: String) extends DeployMessage @@ -142,7 +146,7 @@ private[deploy] object DeployMessages { // Master to Worker & AppClient - case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String) // MasterWebUI To Master diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c048b78910f3..b4edb6109e83 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -65,7 +65,7 @@ private object FaultToleranceTest extends App with Logging { private val workers = ListBuffer[TestWorkerInfo]() private var sc: SparkContext = _ - private val zk = SparkCuratorUtil.newClient(conf) + private val zk = SparkCuratorUtil.newClient(conf) private var numPassed = 0 private var numFailed = 0 diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 2954f932b4f4..ccffb3665298 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -76,12 +76,13 @@ private[deploy] object JsonProtocol { } def writeMasterState(obj: MasterStateResponse): JObject = { + val aliveWorkers = obj.workers.filter(_.isAlive()) ("url" -> obj.uri) ~ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ + ("cores" -> aliveWorkers.map(_.cores).sum) ~ + ("coresused" -> aliveWorkers.map(_.coresUsed).sum) ~ + ("memory" -> aliveWorkers.map(_.memory).sum) ~ + ("memoryused" -> aliveWorkers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ ("activedrivers" -> obj.activeDrivers.toList.map(writeDriverInfo)) ~ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 860e1a24901b..53356addf6ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master @@ -41,8 +40,10 @@ class LocalSparkCluster( extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterRpcEnvs = ArrayBuffer[RpcEnv]() + private val workerRpcEnvs = ArrayBuffer[RpcEnv]() + // exposed for testing + var masterWebUIPort = -1 def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") @@ -53,16 +54,17 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) - masterActorSystems += masterSystem - val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort + val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) + masterWebUIPort = webUiPort + masterRpcEnvs += rpcEnv + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum), _conf) - workerActorSystems += workerSystem + workerRpcEnvs += workerEnv } masters @@ -73,11 +75,11 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. - workerActorSystems.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.shutdown()) // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) + masterRpcEnvs.foreach(_.shutdown()) // masterActorSystems.foreach(_.awaitTermination()) - masterActorSystems.clear() - workerActorSystems.clear() + masterRpcEnvs.clear() + workerRpcEnvs.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 7fa75ac8c2b5..6d14590a1d19 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -334,6 +334,19 @@ class SparkHadoopUtil extends Logging { * Stop the thread that does the delegation token updates. */ private[spark] def stopExecutorDelegationTokenRenewer() {} + + /** + * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. + * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. + */ + private[spark] def getConfBypassingFSCache( + hadoopConf: Configuration, + scheme: String): Configuration = { + val newConf = new Configuration(hadoopConf) + val confKey = s"fs.${scheme}.impl.disable.cache" + newConf.setBoolean(confKey, true) + newConf + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329fa06ba8ba..b1d6ec209d62 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -35,7 +35,8 @@ import org.apache.ivy.core.resolve.ResolveOptions import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher -import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} +import org.apache.ivy.plugins.repository.file.FileRepository +import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -82,13 +83,13 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing - private[spark] var exitFn: () => Unit = () => System.exit(1) + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) private[spark] var printStream: PrintStream = System.err private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) private[spark] def printErrorAndExit(str: String): Unit = { printStream.println("Error: " + str) printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() + exitFn(1) } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to @@ -99,7 +100,7 @@ object SparkSubmit { /_/ """.format(SPARK_VERSION)) printStream.println("Type --help for more information.") - exitFn() + exitFn(0) } def main(args: Array[String]): Unit = { @@ -160,7 +161,7 @@ object SparkSubmit { // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - exitFn() + exitFn(1) } else { throw e } @@ -324,55 +325,20 @@ object SparkSubmit { // Usage: PythonAppRunner

[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs - args.files = mergeFileLists(args.files, args.primaryResource) + if (clusterManager != YARN) { + // The YARN backend distributes the primary file differently, so don't merge it. + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + if (clusterManager != YARN) { + // The YARN backend handles python files differently, so don't merge the lists. + args.files = mergeFileLists(args.files, args.pyFiles) } - args.files = mergeFileLists(args.files, args.pyFiles) if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } } - // In yarn mode for a python app, add pyspark archives to files - // that can be distributed with the job - if (args.isPython && clusterManager == YARN) { - var pyArchives: String = null - val pyArchivesEnvOpt = sys.env.get("PYSPARK_ARCHIVES_PATH") - if (pyArchivesEnvOpt.isDefined) { - pyArchives = pyArchivesEnvOpt.get - } else { - if (!sys.env.contains("SPARK_HOME")) { - printErrorAndExit("SPARK_HOME does not exist for python application in yarn mode.") - } - val pythonPath = new ArrayBuffer[String] - for (sparkHome <- sys.env.get("SPARK_HOME")) { - val pyLibPath = Seq(sparkHome, "python", "lib").mkString(File.separator) - val pyArchivesFile = new File(pyLibPath, "pyspark.zip") - if (!pyArchivesFile.exists()) { - printErrorAndExit("pyspark.zip does not exist for python application in yarn mode.") - } - val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") - if (!py4jFile.exists()) { - printErrorAndExit("py4j-0.8.2.1-src.zip does not exist for python application " + - "in yarn mode.") - } - pythonPath += pyArchivesFile.getAbsolutePath() - pythonPath += py4jFile.getAbsolutePath() - } - pyArchives = pythonPath.mkString(",") - } - - pyArchives = pyArchives.split(",").map { localPath=> - val localURI = Utils.resolveURI(localPath) - if (localURI.getScheme != "local") { - args.files = mergeFileLists(args.files, localURI.toString) - new Path(localPath).getName - } else { - localURI.getPath - } - }.mkString(File.pathSeparator) - sysProps("spark.submit.pyArchives") = pyArchives - } - // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -386,19 +352,10 @@ object SparkSubmit { } } - if (isYarnCluster) { - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) - } - + if (isYarnCluster && args.isR) { // In yarn-cluster mode for a R app, add primary resource to files // that can be distributed with the job - if (args.isR) { - args.files = mergeFileLists(args.files, args.primaryResource) - } + args.files = mergeFileLists(args.files, args.primaryResource) } // Special flag to avoid deprecation warnings at the client @@ -425,9 +382,10 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), - OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), @@ -440,13 +398,11 @@ object SparkSubmit { OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), - - // Yarn client or cluster - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options - OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), @@ -516,17 +472,18 @@ object SparkSubmit { } } + // Let YARN know it's a pyspark app, so it distributes needed libraries. + if (clusterManager == YARN && args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.isPython) { - val mainPyFile = new Path(args.primaryResource).getName - childArgs += ("--primary-py-file", mainPyFile) + childArgs += ("--primary-py-file", args.primaryResource) if (args.pyFiles != null) { - // These files will be distributed to each machine's working directory, so strip the - // path prefix - val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") - childArgs += ("--py-files", pyFilesNames) + childArgs += ("--py-files", args.pyFiles) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") } else if (args.isR) { @@ -700,7 +657,7 @@ object SparkSubmit { /** * Return whether the given main class represents a sql shell. */ - private def isSqlShell(mainClass: String): Boolean = { + private[deploy] def isSqlShell(mainClass: String): Boolean = { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } @@ -753,7 +710,9 @@ private[spark] object SparkSubmitUtils { * @param artifactId the artifactId of the coordinate * @param version the version of the coordinate */ - private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { + override def toString: String = s"$groupId:$artifactId:$version" + } /** * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided @@ -776,6 +735,16 @@ private[spark] object SparkSubmitUtils { } } + /** Path of the local Maven cache. */ + private[spark] def m2Path: File = { + if (Utils.isTesting) { + // test builds delete the maven cache, and this can cause flakiness + new File("dummy", ".m2" + File.separator + "repository") + } else { + new File(System.getProperty("user.home"), ".m2" + File.separator + "repository") + } + } + /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories @@ -787,20 +756,34 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) - val m2Path = ".m2" + File.separator + "repository" + File.separator - localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) + localM2.setRoot(m2Path.toURI.toString) localM2.setUsepoms(true) localM2.setName("local-m2-cache") cr.add(localM2) - val localIvy = new IBiblioResolver - localIvy.setRoot(new File(ivySettings.getDefaultIvyUserDir, - "local" + File.separator).toURI.toString) + val localIvy = new FileSystemResolver + val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local") + localIvy.setLocal(true) + localIvy.setRepository(new FileRepository(localIvyRoot)) val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", "[artifact](-[classifier]).[ext]").mkString(File.separator) - localIvy.setPattern(ivyPattern) + localIvy.addIvyPattern(localIvyRoot.getAbsolutePath + File.separator + ivyPattern) localIvy.setName("local-ivy-cache") cr.add(localIvy) @@ -817,20 +800,6 @@ private[spark] object SparkSubmitUtils { sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -864,18 +833,14 @@ private[spark] object SparkSubmitUtils { md.addDependency(dd) } } - + /** Add exclusion rules for dependencies already included in the spark-assembly */ def addExclusionRules( ivySettings: IvySettings, ivyConfName: String, md: DefaultModuleDescriptor): Unit = { // Add scala exclusion rule - val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*") - val scalaDependencyExcludeRule = - new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null) - scalaDependencyExcludeRule.addConfiguration(ivyConfName) - md.addExcludeRule(scalaDependencyExcludeRule) + md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and // other spark-streaming utility components. Underscore is there to differentiate between @@ -884,13 +849,8 @@ private[spark] object SparkSubmitUtils { "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") components.foreach { comp => - val sparkArtifacts = - new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*") - val sparkDependencyExcludeRule = - new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null) - sparkDependencyExcludeRule.addConfiguration(ivyConfName) - - md.addExcludeRule(sparkDependencyExcludeRule) + md.addExcludeRule(createExclusion(s"org.apache.spark:spark-$comp*:*", ivySettings, + ivyConfName)) } } @@ -903,6 +863,7 @@ private[spark] object SparkSubmitUtils { * @param coordinates Comma-delimited string of maven coordinates * @param remoteRepos Comma-delimited string of remote repositories other than maven central * @param ivyPath The path to the local ivy repository + * @param exclusions Exclusions to apply when resolving transitive dependencies * @return The comma-delimited path to the jars of the given maven artifacts including their * transitive dependencies */ @@ -910,76 +871,105 @@ private[spark] object SparkSubmitUtils { coordinates: String, remoteRepos: Option[String], ivyPath: Option[String], + exclusions: Seq[String] = Nil, isTest: Boolean = false): String = { if (coordinates == null || coordinates.trim.isEmpty) { "" } else { val sysOut = System.out - // To prevent ivy from logging to system out - System.setOut(printStream) - val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings - // Directories for caching downloads through ivy and storing the jars when maven coordinates - // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") + try { + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) } else { - ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") + resolveOptions.setDownload(true) } - printStream.println( - s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") - printStream.println(s"The jars for the packages stored in: $packagesDirectory") - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos, ivySettings) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) - - val ivy = Ivy.newInstance(ivySettings) - // Set resolve options to download transitive dependencies as well - val resolveOptions = new ResolveOptions - resolveOptions.setTransitive(true) - val retrieveOptions = new RetrieveOptions - // Turn downloading and logging off for testing - if (isTest) { - resolveOptions.setDownload(false) - resolveOptions.setLog(LogOptions.LOG_QUIET) - retrieveOptions.setLog(LogOptions.LOG_QUIET) - } else { - resolveOptions.setDownload(true) - } - - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - md.setDefaultConf(ivyConfName) - // Add exclusion rules for Spark and Scala Library - addExclusionRules(ivySettings, ivyConfName, md) - // add all supplied maven artifacts as dependencies - addDependenciesToIvy(md, artifacts, ivyConfName) + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file/ + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + + md.setDefaultConf(ivyConfName) + + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) + + exclusions.foreach { e => + md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName)) + } - // resolve dependencies - val rr: ResolveReport = ivy.resolve(md, resolveOptions) - if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } finally { + System.setOut(sysOut) } - // retrieve all resolved dependencies - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision].[ext]", - retrieveOptions.setConfs(Array(ivyConfName))) - System.setOut(sysOut) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } } + + private def createExclusion( + coords: String, + ivySettings: IvySettings, + ivyConfName: String): ExcludeRule = { + val c = extractMavenCoordinates(coords)(0) + val id = new ArtifactId(new ModuleId(c.groupId, c.artifactId), "*", "*", "*") + val rule = new DefaultExcludeRule(id, ivySettings.getMatcher("glob"), null) + rule.addConfiguration(ivyConfName) + rule + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c0e4c771908b..73ab18332feb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,12 +17,15 @@ package org.apache.spark.deploy +import java.io.{ByteArrayOutputStream, PrintStream} +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.io.Source import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.launcher.SparkSubmitArgumentsParser @@ -169,6 +172,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull + principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -410,6 +415,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case VERSION => SparkSubmit.printVersionAndExit() + case USAGE_ERROR => + printUsageAndExit(1) + case _ => throw new IllegalArgumentException(s"Unexpected argument '$opt'.") } @@ -447,11 +455,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) } - outStream.println( + val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] - |Usage: spark-submit --status [submission ID] --master [spark://...] - | + |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) + outStream.println(command) + + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB + outStream.println( + s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or @@ -477,7 +489,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --properties-file FILE Path to a file from which to load extra properties. If not | specified, this will look for conf/spark-defaults.conf. | - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M). + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: ${mem_mb}M). | --driver-java-options Extra Java options to pass to the driver. | --driver-library-path Extra library path entries to pass to the driver. | --driver-class-path Extra class path entries to pass to the driver. Note that @@ -523,6 +535,65 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | delegation tokens periodically. """.stripMargin ) - SparkSubmit.exitFn() + + if (SparkSubmit.isSqlShell(mainClass)) { + outStream.println("CLI options:") + outStream.println(getSqlShellOptions()) + } + + SparkSubmit.exitFn(exitCode) } + + /** + * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter + * the results to remove unwanted lines. + * + * Since the CLI will call `System.exit()`, we install a security manager to prevent that call + * from working, and restore the original one afterwards. + */ + private def getSqlShellOptions(): String = { + val currentOut = System.out + val currentErr = System.err + val currentSm = System.getSecurityManager() + try { + val out = new ByteArrayOutputStream() + val stream = new PrintStream(out) + System.setOut(stream) + System.setErr(stream) + + val sm = new SecurityManager() { + override def checkExit(status: Int): Unit = { + throw new SecurityException() + } + + override def checkPermission(perm: java.security.Permission): Unit = {} + } + System.setSecurityManager(sm) + + try { + Class.forName(mainClass).getMethod("main", classOf[Array[String]]) + .invoke(null, Array(HELP)) + } catch { + case e: InvocationTargetException => + // Ignore SecurityException, since we throw it above. + if (!e.getCause().isInstanceOf[SecurityException]) { + throw e + } + } + + stream.flush() + + // Get the output and discard any unnecessary lines from it. + Source.fromString(new String(out.toByteArray())).getLines + .filter { line => + !line.startsWith("log4j") && !line.startsWith("usage") + } + .mkString("\n") + } finally { + System.setSecurityManager(currentSm) + System.setOut(currentOut) + System.setErr(currentErr) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 43c8a934c311..79b251e7e62f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. */ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) + private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) - private val REGISTRATION_TIMEOUT = 20.seconds + private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var masterAddress: Address = null - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null private var appId: String = null - private var registered = false - private var activeMasterUrl: String = null + @volatile private var registered = false + + private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint + with Logging { + + private var master: Option[RpcEndpointRef] = None + // To avoid calling listener.disconnected() multiple times + private var alreadyDisconnected = false + @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times + @volatile private var registerMasterFutures: Array[JFuture[_]] = null + @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) - private class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + // A scheduled executor for scheduling the registration actions + private val registrationRetryThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + override def onStart(): Unit = { try { - registerWithMaster() + registerWithMaster(1) } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } - def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterApplication(appDescription) + /** + * Register with all masters asynchronously and returns an array `Future`s for cancellation. + */ + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + for (masterAddress <- masterRpcAddresses) yield { + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = try { + if (registered) { + return + } + logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + val masterRef = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterRef.send(RegisterApplication(appDescription, self)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + }) } } - def registerWithMaster() { - tryRegisterAllMasters() - import context.dispatcher - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + /** + * Register with all masters asynchronously. It will call `registerWithMaster` every + * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times. + * Once we connect to a master successfully, all scheduling work and Futures will be cancelled. + * + * nthRetry means this is the nth attempt to register with master. + */ + private def registerWithMaster(nthRetry: Int) { + registerMasterFutures = tryRegisterAllMasters() + registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { Utils.tryOrExit { - retries += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() + } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - tryRegisterAllMasters() + registerMasterFutures.foreach(_.cancel(true)) + registerWithMaster(nthRetry + 1) } } } - } + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) } - def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. - activeMasterUrl = url - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => logWarning(s"Drop $message because has not yet connected to master") + } } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = { + masterRpcAddresses.contains(remoteAddress) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredApplication(appId_, masterUrl) => + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId_, masterRef) => + // FIXME How to handle the following cases? + // 1. A master receives multiple registrations and sends back multiple + // RegisteredApplications due to an unstable network. + // 2. Receive multiple RegisteredApplication from different masters because the master is + // changing. appId = appId_ registered = true - changeMaster(masterUrl) + master = Some(masterRef) listener.connected(appId) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not + // guaranteed), `ExecutorStateChanged` may be sent to a dead master. + sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -142,24 +184,32 @@ private[spark] class AppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + master = Some(masterRef) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) + masterRef.send(MasterChangeAcknowledged(appId)) + } - case DisassociatedEvent(_, address, _) if address == masterAddress => + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopAppClient => + markDead("Application has been stopped.") + sendToMaster(UnregisterApplication(appId)) + context.reply(true) + stop() + } + + override def onDisconnected(address: RpcAddress): Unit = { + if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = { + if (isPossibleMaster(address)) { logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - master ! UnregisterApplication(appId) - sender ! true - context.stop(self) + } } /** @@ -179,28 +229,31 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop(): Unit = { + if (registrationRetryTimer != null) { + registrationRetryTimer.cancel(true) + } + registrationRetryThread.shutdownNow() + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() } } def start() { // Just launch an actor; it will call back into the listener. - actor = actorSystem.actorOf(Props(new ClientActor)) + endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) } def stop() { - if (actor != null) { + if (endpoint != null) { try { - val timeout = RpcUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + endpoint.askWithRetry[Boolean](StopAppClient) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - actor = null + endpoint = null } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 40835b955058..1c79089303e3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -46,13 +47,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, - conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 298a8201960d..5f5e0fe1c34d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,6 +17,9 @@ package org.apache.spark.deploy.history +import java.util.zip.ZipOutputStream + +import org.apache.spark.SparkException import org.apache.spark.ui.SparkUI private[spark] case class ApplicationAttemptInfo( @@ -62,4 +65,12 @@ private[history] abstract class ApplicationHistoryProvider { */ def getConfig(): Map[String, String] = Map() + /** + * Writes out the event logs to the output stream provided. The logs will be compressed into a + * single zip file and written out. + * @throws SparkException if the logs for the app id cannot be found. + */ + @throws(classOf[SparkException]) + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 45c2be34c868..2cc465e55fce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,16 +17,18 @@ package org.apache.spark.deploy.history -import java.io.{BufferedInputStream, FileNotFoundException, IOException, InputStream} +import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable +import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.AccessControlException -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ @@ -59,7 +61,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -80,12 +83,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - // Constants used to parse Spark 1.0.0 log directories. - private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -143,7 +140,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { try { applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).map { attempt => + appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => val replayBus = new ReplayListenerBus() val ui = { val conf = this.conf.clone() @@ -152,20 +149,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() replayBus.addListener(appListener) val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - - ui.setAppName(s"${appInfo.name} ($appId)") - - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui + appInfo.map { info => + ui.setAppName(s"${info.name} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, + appListener.viewAcls.getOrElse("")) + ui + } } } } catch { @@ -219,6 +216,58 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + + /** + * This method compresses the files passed in, and writes the compressed data out into the + * [[OutputStream]] passed in. Each file is written as a new [[ZipEntry]] with its name being + * the name of the file being compressed. + */ + def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = { + val fs = FileSystem.get(hadoopConf) + val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer + try { + outputStream.putNextEntry(new ZipEntry(entryName)) + ByteStreams.copy(inputStream, outputStream) + outputStream.closeEntry() + } finally { + inputStream.close() + } + } + + applications.get(appId) match { + case Some(appInfo) => + try { + // If no attempt is specified, or there is no attemptId for attempts, return all attempts + appInfo.attempts.filter { attempt => + attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get + }.foreach { attempt => + val logPath = new Path(logDir, attempt.logPath) + // If this is a legacy directory, then add the directory to the zipStream and add + // each file to that directory. + if (isLegacyLogDirectory(fs.getFileStatus(logPath))) { + val files = fs.listStatus(logPath) + zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/")) + zipStream.closeEntry() + files.foreach { file => + val path = file.getPath + zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream) + } + } else { + zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) + } + } + } finally { + zipStream.close() + } + case None => throw new SparkException(s"Logs for $appId not found.") + } + } + + /** * Replay the log files in the list and merge the list of old applications with new ones */ @@ -227,8 +276,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = logs.flatMap { fileStatus => try { val res = replay(fileStatus, bus) - logInfo(s"Application log ${res.logPath} loaded successfully.") - Some(res) + res match { + case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") + case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + } + res } catch { case e: Exception => logError( @@ -374,9 +427,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replays the events in the specified log file and returns information about the associated - * application. + * application. Return `None` if the application ID cannot be located. */ - private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = { + private def replay( + eventLog: FileStatus, + bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") val logInput = @@ -390,16 +445,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted) - new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. Some old versions of Spark generate logs without an app ID, so let + // logs generated by those versions go through. + if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + Some(new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(eventLog).get, + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted)) + } else { + None + } } finally { logInput.close() } @@ -474,10 +537,34 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Returns whether the version of Spark that generated logs records app IDs. App IDs were added + * in Spark 1.1. + */ + private def sparkVersionHasAppId(entry: FileStatus): Boolean = { + if (isLegacyLogDirectory(entry)) { + fs.listStatus(entry.getPath()) + .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } + .map { status => + val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) + version != "1.0" && version != "1.1" + } + .getOrElse(true) + } else { + true + } + } + } -private object FsHistoryProvider { +private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + // Constants used to parse Spark 1.0.0 log directories. + val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 517cbe517624..10638afb7490 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.util.NoSuchElementException +import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import com.google.common.cache._ @@ -25,7 +26,8 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.status.api.v1.{ApplicationInfo, ApplicationsListResource, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, + UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.{SignalLogger, Utils} @@ -125,7 +127,7 @@ class HistoryServer( def initialize() { attachPage(new HistoryPage(this)) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) @@ -172,6 +174,13 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } + override def writeEventLogs( + appId: String, + attemptId: Option[String], + zipStream: ZipOutputStream): Unit = { + provider.writeEventLogs(appId, attemptId, zipStream) + } + /** * Returns the provider configuration to show in the listing page. * diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index a2a97a7877ce..4692d22651c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 1620e95bea21..aa54ed9360f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,9 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +32,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fccceb3ea528..48070768f6ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,20 +21,18 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -47,23 +45,27 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[master] class Master( - host: String, - port: Int, + override val rpcEnv: RpcEnv, + address: RpcAddress, webUiPort: Int, val securityMgr: SecurityManager, val conf: SparkConf) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends ThreadSafeRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + + // TODO Remove it once we don't use akka.serialization.Serialization + private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) @@ -75,10 +77,10 @@ private[master] class Master( val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] - private val addressToWorker = new HashMap[Address, WorkerInfo] + private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] - private val actorToApp = new HashMap[ActorRef, ApplicationInfo] - private val addressToApp = new HashMap[Address, ApplicationInfo] + private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 private val appIdToUI = new HashMap[String, SparkUI] @@ -89,21 +91,22 @@ private[master] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(address.host, "Expected hostname") private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) private val masterSource = new MasterSource(this) - private val webUi = new MasterWebUI(this, webUiPort) + // After onStart, webUi will be set + private var webUi: MasterWebUI = null private val masterPublicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host + if (envVar != null) envVar else address.host } - private val masterUrl = "spark://" + host + ":" + port + private val masterUrl = address.toSparkURL private var masterWebUiUrl: String = _ private var state = RecoveryState.STANDBY @@ -112,7 +115,9 @@ private[master] class Master( private var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ + + private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -130,20 +135,23 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) + Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) } else { None } private val restServerBoundPort = restServer.map(_.start()) - override def preStart() { + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new MasterWebUI(this, webUiPort) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CheckForWorkerTimeOut) + } + }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -157,16 +165,16 @@ private[master] class Master( case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(context.system)) + .newInstance(conf, SerializationExtension(actorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -176,18 +184,17 @@ private[master] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! - logError("Master actor restarted due to exception", reason) - } - - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) } + if (checkForWorkerTimeOutTask != null) { + checkForWorkerTimeOutTask.cancel(true) + } + forwardMessageThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -197,14 +204,14 @@ private[master] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -215,8 +222,11 @@ private[master] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = forwardMessageThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CompleteRecovery) + } + }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } } @@ -227,111 +237,42 @@ private[master] class Master( System.exit(0) } - case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) => - { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - sender, workerUiPort, publicAddress) + workerRef, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) - } - } - } - - case RequestSubmitDriver(description) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only accept driver submissions in ALIVE state." - sender ! SubmitDriverResponse(false, None, msg) - } else { - logInfo("Driver submitted " + description.command.mainClass) - val driver = createDriver(description) - persistenceEngine.addDriver(driver) - waitingDrivers += driver - drivers.add(driver) - schedule() - - // TODO: It might be good to instead have the submission client poll the master to determine - // the current status of the driver. For now it's simply "fire and forget". - - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") - } - } - - case RequestKillDriver(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - s"Can only kill drivers in ALIVE state." - sender ! KillDriverResponse(driverId, success = false, msg) - } else { - logInfo("Asked to kill driver " + driverId) - val driver = drivers.find(_.id == driverId) - driver match { - case Some(d) => - if (waitingDrivers.contains(d)) { - waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) - } else { - // We just notify the worker to kill the driver here. The final bookkeeping occurs - // on the return path when the worker submits a state change back to the master - // to notify it that the driver was successfully killed. - d.worker.foreach { w => - w.actor ! KillDriver(driverId) - } - } - // TODO: It would be nice for this to be a synchronous response - val msg = s"Kill request for $driverId submitted" - logInfo(msg) - sender ! KillDriverResponse(driverId, success = true, msg) - case None => - val msg = s"Driver $driverId has already finished or does not exist" - logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) - } - } - } - - case RequestDriverStatus(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only request driver status in ALIVE state." - sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) - } else { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } - case RegisterApplication(description) => { + case RegisterApplication(description, driver) => { + // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response } else { logInfo("Registering app " + description.name) - val app = createApplication(description, sender) + val app = createApplication(description, driver) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + driver.send(RegisteredApplication(app.id, self)) schedule() } } @@ -343,7 +284,7 @@ private[master] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -384,7 +325,7 @@ private[master] class Master( } } - case Heartbeat(workerId) => { + case Heartbeat(workerId, worker) => { idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -392,7 +333,7 @@ private[master] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + worker.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -444,30 +385,103 @@ private[master] class Master( logInfo(s"Received unregister request from application $applicationId") idToApp.get(applicationId).foreach(finishApplication) - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() } + } - case RequestMasterState => { - sender ! MasterStateResponse( - host, port, restServerBoundPort, - workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." + context.reply(SubmitDriverResponse(self, false, None, msg)) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + context.reply(SubmitDriverResponse(self, true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) + } } - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. + d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + context.reply(KillDriverResponse(self, driverId, success = true, msg)) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } + } + } + + case RequestDriverStatus(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + context.reply( + DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + context.reply(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) + case None => + context.reply(DriverStatusResponse(found = false, None, None, None, None)) + } + } + } + + case RequestMasterState => { + context.reply(MasterStateResponse( + address.host, address.port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } case BoundPortsRequest => { - sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) + context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) } } + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + private def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 @@ -479,7 +493,7 @@ private[master] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -496,7 +510,7 @@ private[master] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.endpoint.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -505,10 +519,8 @@ private[master] class Master( private def completeRecovery() { // Ensure "only-once" recovery semantics using a short synchronization period. - synchronized { - if (state != RecoveryState.RECOVERING) { return } - state = RecoveryState.COMPLETING_RECOVERY - } + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) @@ -623,10 +635,10 @@ private[master] class Master( private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.endpoint.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { @@ -638,7 +650,7 @@ private[master] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -661,11 +673,11 @@ private[master] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.endpoint.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -687,14 +699,15 @@ private[master] class Master( schedule() } - private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): + ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -703,7 +716,7 @@ private[master] class Master( applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(app.driver) = app + endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app } @@ -717,8 +730,8 @@ private[master] class Master( logInfo("Removing app " + app.id) apps -= app idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address + endpointToApp -= app.driver + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -735,19 +748,19 @@ private[master] class Master( for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) + exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id)) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! ApplicationFinished(app.id) + w.endpoint.send(ApplicationFinished(app.id)) } } } @@ -768,7 +781,7 @@ private[master] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, None, app.desc.eventLogCodec) + eventLogDir, app.id, app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) @@ -832,14 +845,14 @@ private[master] class Master( private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray for (worker <- toRemove) { if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) + worker.id, WORKER_TIMEOUT_MS / 1000)) removeWorker(worker) } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } @@ -862,7 +875,7 @@ private[master] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.endpoint.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -891,57 +904,33 @@ private[master] class Master( } private[deploy] object Master extends Logging { - val systemName = "sparkMaster" - private val actorName = "Master" + val SYSTEM_NAME = "sparkMaster" + val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() - } - - /** - * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaUrl(sparkUrl: String, protocol: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - AkkaUtils.address(protocol, systemName, host, port, actorName) - } - - /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaAddress(sparkUrl: String, protocol: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address(protocol, systemName, host, port) + val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** - * Start the Master and return a four tuple of: - * (1) The Master actor system - * (2) The bound port - * (3) The web UI bound port - * (4) The REST server bound port, if any + * Start the Master and return a three tuple of: + * (1) The Master RpcEnv + * (2) The web UI bound port + * (3) The REST server bound port, if any */ - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { + conf: SparkConf): (RpcEnv, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf( - Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = RpcUtils.askTimeout(conf) - val portsRequest = actor.ask(BoundPortsRequest)(timeout) - val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] - (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) + val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) + val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, + new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) + val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 15c6296888f7..68c937188b33 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,7 +28,7 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to Master + // Master to itself case object CheckForWorkerTimeOut diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 9b3d48c6edc8..f75196660520 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val endpoint: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { @@ -107,4 +105,6 @@ private[spark] class WorkerInfo( def setState(state: WorkerState.Value): Unit = { this.state = state } + + def isAlive(): Boolean = this.state == WorkerState.ALIVE } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 52758d6a7c4b..6fdff86f66e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,10 +17,7 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} import org.apache.spark.deploy.SparkCuratorUtil diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 80db6d474b5c..328d95a7a0c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { - + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 06e265f99e23..e28e7e379ac9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask - import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorDesc @@ -32,14 +29,12 @@ import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 756927682cd2..c3e20ebf8d6e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - Await.result(stateFuture, timeout) + master.askWithRetry[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { @@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } def handleDriverKillRequest(request: HttpServletRequest): Unit = { - handleKillRequest(request, id => { master ! RequestKillDriver(id) }) + handleKillRequest(request, id => { + master.ask[KillDriverResponse](RequestKillDriver(id)) + }) } private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { @@ -75,6 +73,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) + val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", @@ -108,12 +107,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { }.getOrElse { Seq.empty } } -
  • Workers: {state.workers.size}
  • -
  • Cores: {state.workers.map(_.cores).sum} Total, - {state.workers.map(_.coresUsed).sum} Used
  • -
  • Memory: - {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total, - {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
  • +
  • Alive Workers: {aliveWorkers.size}
  • +
  • Cores in use: {aliveWorkers.map(_.cores).sum} Total, + {aliveWorkers.map(_.coresUsed).sum} Used
  • +
  • Memory in use: + {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total, + {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
  • Applications: {state.activeApps.size} Running, {state.completedApps.size} Completed
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index eb26e9f99c70..6174fc11f83d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -19,10 +19,10 @@ package org.apache.spark.deploy.master.ui import org.apache.spark.Logging import org.apache.spark.deploy.master.Master -import org.apache.spark.status.api.v1.{ApplicationsListResource, ApplicationInfo, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource, ApplicationInfo, + UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. @@ -32,8 +32,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { - val masterActorRef = master.self - val timeout = RpcUtils.askTimeout(master.conf) + val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) val masterPage = new MasterPage(this) @@ -47,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) attachPage(new HistoryNotFoundPage(this)) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index be8560d10fc6..e8ef60bd5428 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

    Driver state information for driver id {driverId}

    - Back to Drivers + Back to Drivers

    Driver state: {driverState.state}

    diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 6078f50518ba..1fe956320a1b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -57,7 +57,11 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private val supportedMasterPrefixes = Seq("spark://", "mesos://") - private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + private val masters: Array[String] = if (master.startsWith("spark://")) { + Utils.parseStandaloneMasterUrls(master) + } else { + Array(master) + } // Set of masters that lost contact with us, used to keep track of // whether there are masters still alive for us to communicate with diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 502b9bb701cc..d5b9bcab1423 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import akka.actor.ActorRef import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** @@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to * @param masterConf the conf used by the Master - * @param masterActor reference to the Master actor to which requests can be sent + * @param masterEndpoint reference to the Master endpoint to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { protected override val submitRequestServlet = - new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = - new StandaloneKillRequestServlet(masterActor, masterConf) + new StandaloneKillRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = - new StandaloneStatusRequestServlet(masterActor, masterConf) + new StandaloneStatusRequestServlet(masterEndpoint, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( - DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion k.message = response.message @@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( - DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion @@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ private[rest] class StandaloneSubmitRequestServlet( - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String, conf: SparkConf) extends SubmitRequestServlet { @@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion submitResponse.message = response.message diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 8198296eeb34..868cc35d06ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -59,7 +59,7 @@ private[mesos] class MesosSubmitRequestServlet( extends SubmitRequestServlet { private val DEFAULT_SUPERVISE = false - private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_MEMORY = Utils.DEFAULT_DRIVER_MEM_MB // mb private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 0a1d60f58bc5..45a3f4304543 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConversions._ import scala.collection.Map import org.apache.spark.Logging +import org.apache.spark.SecurityManager import org.apache.spark.deploy.Command import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils @@ -40,12 +41,14 @@ object CommandUtils extends Logging { */ def buildProcessBuilder( command: Command, + securityMgr: SecurityManager, memory: Int, sparkHome: String, substituteArguments: String => String, classPaths: Seq[String] = Seq[String](), env: Map[String, String] = sys.env): ProcessBuilder = { - val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env) + val localCommand = buildLocalCommand( + command, securityMgr, substituteArguments, classPaths, env) val commandSeq = buildCommandSeq(localCommand, memory, sparkHome) val builder = new ProcessBuilder(commandSeq: _*) val environment = builder.environment() @@ -69,6 +72,7 @@ object CommandUtils extends Logging { */ private def buildLocalCommand( command: Command, + securityMgr: SecurityManager, substituteArguments: String => String, classPath: Seq[String] = Seq[String](), env: Map[String, String]): Command = { @@ -76,20 +80,26 @@ object CommandUtils extends Logging { val libraryPathEntries = command.libraryPathEntries val cmdLibraryPath = command.environment.get(libraryPathName) - val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { + var newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator))) } else { command.environment } + // set auth secret to env variable if needed + if (securityMgr.isAuthenticationEnabled) { + newEnvironment += (SecurityManager.ENV_AUTH_SECRET -> securityMgr.getSecretKey) + } + Command( command.mainClass, command.arguments.map(substituteArguments), newEnvironment, command.classPathEntries ++ classPath, Seq[String](), // library path already captured in environment variable - command.javaOpts) + // filter out auth secret from java options + command.javaOpts.filterNot(_.startsWith("-D" + SecurityManager.SPARK_AUTH_SECRET_CONF))) } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index ef7a703bffe6..ec51c3d935d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -21,7 +21,6 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{Utils, Clock, SystemClock} /** @@ -43,7 +43,7 @@ private[deploy] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String, val securityManager: SecurityManager) extends Logging { @@ -85,8 +85,8 @@ private[deploy] class DriverRunner( } // TODO: If we add ability to submit multiple jars they should also be added here - val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, - sparkHome.getAbsolutePath, substituteVariables) + val builder = CommandUtils.buildProcessBuilder(driverDesc.command, securityManager, + driverDesc.mem, sparkHome.getAbsolutePath, substituteVariables) launchDriver(builder, driverDir, driverDesc.supervise) } catch { @@ -107,7 +107,7 @@ private[deploy] class DriverRunner( finalState = Some(state) - worker ! DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 7aa85b732fc8..29a504228557 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -21,11 +21,11 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.util.Utils @@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val webUiPort: Int, @@ -91,7 +91,7 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -125,8 +125,8 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory, - sparkHome.getAbsolutePath, substituteVariables) + val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) @@ -159,7 +159,7 @@ private[deploy] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c8df024dda35..82e9578bbcba 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,15 +21,14 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.ExecutionContext import scala.util.Random - -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} @@ -38,32 +37,39 @@ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} -/** - * @param masterAkkaUrls Each url should be a valid akka url. - */ private[worker] class Worker( - host: String, - port: Int, + override val rpcEnv: RpcEnv, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], - actorSystemName: String, - actorName: String, + masterRpcAddresses: Array[RpcAddress], + systemName: String, + endpointName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends ThreadSafeRpcEndpoint with Logging { + + private val host = rpcEnv.address.host + private val port = rpcEnv.address.port Utils.checkHost(host, "Expected hostname") assert (port > 0) + // A scheduled executor used to send messages at the specified time. + private val forwordMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + + // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` + // methods. + private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) + // For worker and executor IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -79,32 +85,26 @@ private[worker] class Worker( val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * - REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 - * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)) + private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)) private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") - private var master: ActorSelection = null - private var masterAddress: Address = null + private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val akkaUrl = AkkaUtils.address( - AkkaUtils.protocol(context.system), - actorSystemName, - host, - port, - actorName) - @volatile private var registered = false - @volatile private var connected = false + private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private var registered = false + private var connected = false private val workerId = generateWorkerId() private val sparkHome = if (testing) { @@ -136,7 +136,18 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - private var registrationRetryTimer: Option[Cancellable] = None + private var registerMasterFutures: Array[JFuture[_]] = null + private var registrationRetryTimer: Option[JScheduledFuture[_]] = None + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) var coresUsed = 0 var memoryUsed = 0 @@ -162,14 +173,13 @@ private[worker] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -181,24 +191,32 @@ private[worker] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(url: String, uiUrl: String) { + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. - activeMasterUrl = url + activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) + master = Some(masterRef) connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } - private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + masterRpcAddresses.map { masterAddress => + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + }) } } @@ -211,8 +229,7 @@ private[worker] class Worker( Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") /** @@ -235,21 +252,48 @@ private[worker] class Worker( * still not safe if the old master recovers within this interval, but this is a much * less likely scenario. */ - if (master != null) { - master ! RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) - } else { - // We are retrying the initial registration - tryRegisterAllMasters() + master match { + case Some(masterRef) => + // registered == false && master != None means we lost the connection to master, so + // masterRef cannot be used and we need to recreate it again. Note: we must not set + // master to None due to the above comments. + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + val masterAddress = masterRef.address + registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + })) + case None => + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + // We are retrying the initial registration + registerMasterFutures = tryRegisterAllMasters() } // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = Some( + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) } } else { logError("All masters are unresponsive! Giving up.") @@ -258,41 +302,67 @@ private[worker] class Worker( } } + /** + * Cancel last registeration retry, or do nothing if no retry + */ + private def cancelLastRegistrationRetry(): Unit = { + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures = null + } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = None + } + private def registerWithMaster() { - // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // onDisconnected may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. registrationRetryTimer match { case None => registered = false - tryRegisterAllMasters() + registerMasterFutures = tryRegisterAllMasters() connectionAttemptCount = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + " attempt scheduled already.") } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterUrl, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterUrl) + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) registered = true - changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + changeMaster(masterRef, masterWebUiUrl) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(SendHeartbeat) + } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! Heartbeat(workerId) } + if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Copy ids so that it can be used in the cleanup thread. + val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { val appDirs = workDir.listFiles() if (appDirs == null) { @@ -302,27 +372,27 @@ private[worker] class Worker( // the directory is used by an application - check that the application is not running // when cleaning up val appIdFromDir = dir.getName - val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + val isAppStillRunning = appIds.contains(appIdFromDir) dir.isDirectory && !isAppStillRunning && - !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => logInfo(s"Removing directory: ${dir.getPath}") Utils.deleteRecursively(dir) } - } + }(cleanupThreadExecutor) - cleanupFuture onFailure { + cleanupFuture.onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) - } + }(cleanupThreadExecutor) - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl, masterWebUiUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + changeMaster(masterRef, masterWebUiUrl) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case RegisterWorkerFailed(message) => if (!registered) { @@ -369,14 +439,14 @@ private[worker] class Worker( publicAddress, sparkHome, executorDir, - akkaUrl, + workerUri, conf, appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -384,14 +454,14 @@ private[worker] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => + sendToMaster(executorStateChanged) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -434,7 +504,7 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl, + workerUri, securityMgr) drivers(driverId) = driver driver.start() @@ -453,7 +523,7 @@ private[worker] class Worker( } } - case DriverStateChanged(driverId, state, exception) => { + case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { state match { case DriverState.ERROR => logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") @@ -466,23 +536,13 @@ private[worker] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - master ! DriverStateChanged(driverId, state, exception) + sendToMaster(driverStageChanged) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem coresUsed -= driver.driverDesc.cores } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, drivers.values.toList, - finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) - case ReregisterWithMaster => reregisterWithMaster() @@ -491,6 +551,21 @@ private[worker] class Worker( maybeCleanupApplication(id) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestWorkerState => + context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (master.exists(_.address == remoteAddress)) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false @@ -510,13 +585,29 @@ private[worker] class Worker( } } + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => + logWarning( + s"Dropping $message because the connection to master has not yet been established") + } + } + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { + cleanupThreadExecutor.shutdownNow() metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + cancelLastRegistrationRetry() + forwordMessageScheduler.shutdownNow() + registerMasterThreadPool.shutdownNow() executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() @@ -530,12 +621,12 @@ private[deploy] object Worker extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, @@ -544,18 +635,17 @@ private[deploy] object Worker extends Logging { masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None, - conf: SparkConf = new SparkConf): (ActorSystem, Int) = { + conf: SparkConf = new SparkConf): RpcEnv = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses, + systemName, actorName, workDir, conf, securityMgr)) + rpcEnv } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 9678631da9f6..1d2ecab51761 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -164,7 +164,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } } // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) + math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) } def checkWorkerMemory(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 83fb991891a4..fae5640b9a21 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.worker import org.apache.spark.Logging -import org.apache.spark.deploy.DeployMessages.SendHeartbeat import org.apache.spark.rpc._ /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 88170d4df305..5a1d06eb87db 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.worker.ui +import java.io.File +import java.net.URI import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -29,6 +31,7 @@ import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir + private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 @@ -129,6 +132,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offsetOption: Option[Long], byteLength: Int ): (String, Long, Long, Long) = { + + if (!supportedLogTypes.contains(logType)) { + return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) + } + + // Verify that the normalized path of the log directory is in the working directory + val normalizedUri = new URI(logDirectory).normalize() + val normalizedLogDir = new File(normalizedUri.getPath) + if (!Utils.isInDirectory(workDir, normalizedLogDir)) { + return ("Error: invalid log directory " + logDirectory, 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") @@ -144,7 +159,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offset } } - val endIndex = math.min(startIndex + totalLength, totalLength) + val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") val logText = Utils.offsetBytes(files, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 9f9f27d71e1a..fd905feb97e9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - private val workerActor = parent.worker.self - private val timeout = parent.timeout + private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index b3bb5f911dbd..334a5b10142a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - private[ui] val timeout = RpcUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ed159dec4f99..34d4cfdca773 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -33,7 +33,7 @@ import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -55,18 +55,22 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() override def onStart() { - import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) ref.ask[RegisteredExecutor.type]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) - } onComplete { + }(ThreadUtils.sameThread).onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } - case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) - } + case Failure(e) => { + logError(s"Cannot register with driver: $driverUrl", e) + System.exit(1) + } + }(ThreadUtils.sameThread) } def extractLogUrls: Map[String, String] = { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 06152f16ae61..a3b4561b07e7 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -43,22 +43,22 @@ class TaskMetrics extends Serializable { private var _hostname: String = _ def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value - + /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value - - + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value - + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ @@ -94,8 +94,8 @@ class TaskMetrics extends Serializable { */ private var _diskBytesSpilled: Long = _ def diskBytesSpilled: Long = _diskBytesSpilled - def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value + private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read @@ -261,7 +261,7 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. @@ -315,7 +315,7 @@ class ShuffleReadMetrics extends Serializable { def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - + /** * Number of local blocks fetched in this shuffle by this task */ @@ -333,7 +333,7 @@ class ShuffleReadMetrics extends Serializable { def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - + /** * Total number of remote bytes read from the shuffle by this task */ @@ -381,7 +381,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @@ -389,7 +389,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** * Total number of records written to the shuffle by this task */ diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0756cdb2ed8e..0d8ac1f80a9f 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,7 +17,7 @@ package org.apache.spark.io -import java.io.{InputStream, OutputStream} +import java.io.{IOException, InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} @@ -154,8 +154,53 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt - new SnappyOutputStream(s, blockSize) + new SnappyOutputStreamWrapper(new SnappyOutputStream(s, blockSize)) } override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } + +/** + * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version + * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. + */ +private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends OutputStream { + + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte]): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b, off, len) + } + + override def flush(): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.flush() + } + + override def close(): Unit = { + if (!closed) { + closed = true + os.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index cfd20392d12f..390d148bc97f 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -60,7 +60,7 @@ trait SparkHadoopMapReduceUtil { val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if(isMap) "MAP" else "REDUCE") + taskTypeClass, if (isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 8edf49378068..d7495551ad23 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -23,10 +23,10 @@ import java.util.Properties import scala.collection.mutable import scala.util.matching.Regex -import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} -private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { +private[spark] class MetricsConfig(conf: SparkConf) extends Logging { private val DEFAULT_PREFIX = "*" private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r @@ -46,23 +46,14 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi // Add default properties in case there's no properties file setDefaultProperties(properties) - // If spark.metrics.conf is not set, try to get file in class path - val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse { - try { - Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME)) - } catch { - case e: Exception => - logError("Error loading default configuration file", e) - None - } - } + loadPropertiesFromFile(conf.getOption("spark.metrics.conf")) - isOpt.foreach { is => - try { - properties.load(is) - } finally { - is.close() - } + // Also look for the properties in provided Spark configuration + val prefix = "spark.metrics.conf." + conf.getAll.foreach { + case (k, v) if k.startsWith(prefix) => + properties.setProperty(k.substring(prefix.length()), v) + case _ => } propertyCategories = subProperties(properties, INSTANCE_REGEX) @@ -97,5 +88,31 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) } } -} + /** + * Loads configuration from a config file. If no config file is provided, try to get file + * in class path. + */ + private[this] def loadPropertiesFromFile(path: Option[String]): Unit = { + var is: InputStream = null + try { + is = path match { + case Some(f) => new FileInputStream(f) + case None => Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME) + } + + if (is != null) { + properties.load(is) + } + } catch { + case e: Exception => + val file = path.getOrElse(DEFAULT_METRICS_CONF_FILENAME) + logError(s"Error loading configuration file $file", e) + } finally { + if (is != null) { + is.close() + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 9150ad35712a..ed5131c79fdc 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -70,8 +70,7 @@ private[spark] class MetricsSystem private ( securityMgr: SecurityManager) extends Logging { - private[this] val confFile = conf.get("spark.metrics.conf", null) - private[this] val metricsConfig = new MetricsConfig(Option(confFile)) + private[this] val metricsConfig = new MetricsConfig(conf) private val sinks = new mutable.ArrayBuffer[Sink] private val sources = new mutable.ArrayBuffer[Source] diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index e8b3074e8f1a..11dfcfe2f04e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -26,9 +26,9 @@ import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem private[spark] class Slf4jSink( - val property: Properties, + val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) + securityMgr: SecurityManager) extends Sink { val SLF4J_DEFAULT_PERIOD = 10 val SLF4J_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala index 90e3aa70b99e..670e68366332 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala @@ -20,4 +20,4 @@ package org.apache.spark.metrics /** * Sinks used in Spark's metrics system. */ -package object sink +package object sink diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index b573f1a8a5fc..67a376102994 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -110,7 +110,7 @@ private[nio] class BlockMessage() { def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data - def getLevel: StorageLevel = level + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() @@ -155,7 +155,7 @@ private[nio] class BlockMessage() { override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 1ba25aa74aa0..7d0806f0c258 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -114,8 +114,8 @@ private[nio] object BlockMessageArray { val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear + val buffer = ByteBuffer.allocate(100) + buffer.clear() BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, StorageLevel.MEMORY_ONLY_SER)) } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 6b898bd4bfc1..1499da07bb83 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -326,15 +326,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // MUST be called within the selector loop def connect() { - try{ + try { channel.register(selector, SelectionKey.OP_CONNECT) channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { - case e: Exception => { + case e: Exception => logError("Error connecting to " + address, e) callOnExceptionCallbacks(e) - } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 497871ed6d5e..c0bca2c4bc99 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -635,12 +635,11 @@ private[nio] class ConnectionManager( val message = securityMsgResp.toBufferMessage if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => { + } catch { + case e: Exception => logError("Error handling sasl client authentication", e) waitingConn.close() throw new IOException("Error evaluating sasl response: ", e) - } } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index 747a2088a725..232c552f9865 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -75,7 +75,7 @@ private[nio] class SecurityMessage extends Logging { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - connectionId = idBuilder.toString() + connectionId = idBuilder.toString() val tokenLength = buffer.getInt() token = new Array[Byte](tokenLength) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 2ab41ba488ff..8ae76c5f72f2 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.4.0-SNAPSHOT" + val SPARK_VERSION = "1.5.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 3ef3cc219dec..91b07ce3af1b 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.collection.OpenHashMap * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] { + extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { var outputsMerged = 0 - var sums = new OpenHashMap[T,Long]() // Sum of counts for each key + var sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ec185340c3a2..ca1eb1f4e4a9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,8 +19,10 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.util.ThreadUtils + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.ExecutionContext import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} @@ -66,6 +68,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val f = new ComplexFutureAction[Seq[T]] f.run { + // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which + // is a cached thread pool. val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 @@ -81,9 +85,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi numPartsToTry = partsScanned * 4 } else { // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max(1, + numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } @@ -101,7 +105,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi partsScanned += numPartsToTry } results.toSeq - } + }(AsyncRDDActions.futureExecutionContext) f } @@ -123,3 +127,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi (index, data) => Unit, Unit) } } + +private object AsyncRDDActions { + val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("AsyncRDDActions-future", 128)) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 0d130dd4c7a6..e17bd47905d7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -21,15 +21,14 @@ import java.io.IOException import scala.reflect.ClassTag -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). @@ -38,9 +37,11 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) + private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) + + override def getCheckpointFile: Option[String] = Some(checkpointPath) override def getPartitions: Array[Partition] = { val cpath = new Path(checkpointPath) @@ -49,7 +50,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.length + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) @@ -60,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - override def getPreferredLocations(split: Partition): Seq[String] = { val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) @@ -75,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) CheckpointRDD.readFromFile(file, broadcastedConf, context) } - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } + // CheckpointRDD should not be checkpointed again + override def checkpoint(): Unit = { } + override def doCheckpoint(): Unit = { } } private[spark] object CheckpointRDD extends Logging { @@ -87,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T: ClassTag]( path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1 )(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get @@ -135,7 +133,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T]( path: Path, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], context: TaskContext ): Iterator[T] = { val env = SparkEnv.get @@ -164,7 +162,7 @@ private[spark] object CheckpointRDD extends Logging { val path = new Path(hdfsPath, "temp") val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf)) sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 0c1b02c07d09..663eebb8e419 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -310,11 +310,11 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: def throwBalls() { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { + for ((p, i) <- prev.partitions.zipWithIndex) { groupArr(i).arr += p } } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { + for (i <- 0 until maxPartitions) { val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2cefe63d44b2..bee59a437f12 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -100,7 +100,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp @DeveloperApi class HadoopRDD[K, V]( @transient sc: SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], @@ -121,8 +121,8 @@ class HadoopRDD[K, V]( minPartitions: Int) = { this( sc, - sc.broadcast(new SerializableWritable(conf)) - .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + sc.broadcast(new SerializableConfiguration(conf)) + .asInstanceOf[Broadcast[SerializableConfiguration]], None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2ab967f4bb31..f827270ee6a4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -74,7 +74,7 @@ class NewHadoopRDD[K, V]( with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) // private val serializableConf = new SerializableWritable(conf) private val jobTrackerId: String = { @@ -196,7 +196,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a6d5d2c94e17..91a6a2d03985 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -44,7 +44,7 @@ import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -296,6 +296,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope { + val cleanedF = self.sparkContext.clean(func) if (keyClass.isArray) { throw new SparkException("reduceByKeyLocally() does not support array keys") @@ -305,7 +306,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val map = new JHashMap[K, V] iter.foreach { pair => val old = map.get(pair._1) - map.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + map.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } Iterator(map) } : Iterator[JHashMap[K, V]] @@ -313,7 +314,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { m2.foreach { pair => val old = m1.get(pair._1) - m1.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] @@ -327,7 +328,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) reduceByKeyLocally(func) } - /** + /** * Count the number of elements for each key, collecting the results to a local Map. * * Note that this method should only be used if the resulting map is expected to be small, as @@ -466,7 +467,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 val bufs = combineByKey[CompactBuffer[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -1001,7 +1002,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1010,7 +1011,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, @@ -1026,7 +1027,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L Utils.tryWithSafeFinally { @@ -1064,7 +1065,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val wrappedConf = new SerializableWritable(hadoopConf) + val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 7598ff617b39..9e3880714a79 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -86,7 +86,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( } val location = if (locations.isEmpty) { None - } else { + } else { // Find the location that maximum number of parent partitions prefer Some(locations.groupBy(x => x).maxBy(_._2.length)._1) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index f7fa37e4cdcd..9f7ebae3e9af 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag]( @transient private var partitions_ : Array[Partition] = null /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) /** * Get the list of dependencies of this RDD, taking into account whether the @@ -434,11 +434,11 @@ abstract class RDD[T: ClassTag]( * @return A random sub-sample of the RDD without replacement. */ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { - this.mapPartitionsWithIndex { case (index, partition) => + this.mapPartitionsWithIndex( { (index, partition) => val sampler = new BernoulliCellSampler[T](lb, ub) sampler.setSeed(seed + index) sampler.sample(partition) - } + }, preservesPartitioning = true) } /** @@ -454,7 +454,7 @@ abstract class RDD[T: ClassTag]( withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { - val numStDev = 10.0 + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") @@ -890,6 +890,10 @@ abstract class RDD[T: ClassTag]( * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. + * + * Note: this results in multiple Spark jobs, and if the input RDD is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input RDD should be cached first. */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { @@ -1015,9 +1019,16 @@ abstract class RDD[T: ClassTag]( /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. */ def fold(zeroValue: T)(op: (T, T) => T): T = withScope { // Clone the zero value since we will also be serializing it as part of tasks @@ -1131,8 +1142,8 @@ abstract class RDD[T: ClassTag]( if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } - val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T,Long] = { (ctx, iter) => - val map = new OpenHashMap[T,Long] + val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T, Long] = { (ctx, iter) => + val map = new OpenHashMap[T, Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } @@ -1440,12 +1451,16 @@ abstract class RDD[T: ClassTag]( * executed on this RDD. It is strongly recommended that this RDD is persisted in * memory, otherwise saving it on a file will require recomputation. */ - def checkpoint() { + def checkpoint(): Unit = { if (context.checkpointDir.isEmpty) { throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() + // NOTE: we use a global lock here due to complexities downstream with ensuring + // children RDD partitions point to the correct parent partitions. In the future + // we should revisit this consideration. + RDDCheckpointData.synchronized { + checkpointData = Some(new RDDCheckpointData(this)) + } } } @@ -1486,7 +1501,7 @@ abstract class RDD[T: ClassTag]( private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassTag] = { + protected[spark] def firstParent[U: ClassTag]: RDD[U] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } @@ -1578,15 +1593,15 @@ abstract class RDD[T: ClassTag]( case 0 => Seq.empty case 1 => val d = rdd.dependencies.head - debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]], true) case _ => val frontDeps = rdd.dependencies.take(len - 1) val frontDepStrings = frontDeps.flatMap( - d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]])) + d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]])) val lastDep = rdd.dependencies.last val lastDepStrings = - debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) (frontDepStrings ++ lastDepStrings) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 1722c27e5500..4f954363bed8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -22,15 +22,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.spark._ -import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} +import org.apache.spark.util.SerializableConfiguration /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + * [ Initialized --> checkpointing in progress --> checkpointed ]. */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value - val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value + val Initialized, CheckpointingInProgress, Checkpointed = Value } /** @@ -45,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) import CheckpointState._ // The checkpoint state of the associated RDD. - var cpState = Initialized + private var cpState = Initialized // The file to which the associated RDD has been checkpointed to - @transient var cpFile: Option[String] = None + private var cpFile: Option[String] = None // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - var cpRDD: Option[RDD[T]] = None + // This is defined if and only if `cpState` is `Checkpointed`. + private var cpRDD: Option[CheckpointRDD[T]] = None - // Mark the RDD for checkpointing - def markForCheckpoint() { - RDDCheckpointData.synchronized { - if (cpState == Initialized) cpState = MarkedForCheckpoint - } - } + // TODO: are we sure we need to use a global lock in the following methods? // Is the RDD already checkpointed - def isCheckpointed: Boolean = { - RDDCheckpointData.synchronized { cpState == Checkpointed } + def isCheckpointed: Boolean = RDDCheckpointData.synchronized { + cpState == Checkpointed } // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile: Option[String] = { - RDDCheckpointData.synchronized { cpFile } + def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized { + cpFile } - // Do the checkpointing of the RDD. Called after the first job using that RDD is over. - def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return + /** + * Materialize this RDD and write its content to a reliable DFS. + * This is called immediately after the first action invoked on this RDD has completed. + */ + def doCheckpoint(): Unit = { + + // Guard against multiple threads checkpointing the same RDD by + // atomically flipping the state of this RDDCheckpointData RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { + if (cpState == Initialized) { cpState = CheckpointingInProgress } else { return @@ -86,18 +86,20 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) + throw new SparkException(s"Failed to create checkpoint path $path") } // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( - new SerializableWritable(rdd.context.hadoopConfiguration)) + new SerializableConfiguration(rdd.context.hadoopConfiguration)) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { rdd.context.cleaner.foreach { cleaner => cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) } } + + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( @@ -112,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed } - logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) - } - - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } + logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}") } - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } + def getPartitions: Array[Partition] = RDDCheckpointData.synchronized { + cpRDD.get.partitions } - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } + def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { + cpRDD } } private[spark] object RDDCheckpointData { + + /** Return the path of the directory to which this RDD's checkpoint data is written. */ def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { - sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } + sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") } } + /** Clean up the files associated with the checkpoint data for this RDD. */ def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { rddCheckpointDataPath(sc, rddId).foreach { path => val fs = path.getFileSystem(sc.hadoopConfiguration) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 2725826f421f..44667281c106 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * A general, named code block representing an operation that instantiates RDDs. @@ -43,9 +43,8 @@ import org.apache.spark.SparkContext @JsonPropertyOrder(Array("id", "name", "parent")) private[spark] class RDDOperationScope( val name: String, - val parent: Option[RDDOperationScope] = None) { - - val id: Int = RDDOperationScope.nextScopeId() + val parent: Option[RDDOperationScope] = None, + val id: String = RDDOperationScope.nextScopeId().toString) { def toJson: String = { RDDOperationScope.jsonMapper.writeValueAsString(this) @@ -75,7 +74,7 @@ private[spark] class RDDOperationScope( * A collection of utility methods to construct a hierarchical representation of RDD scopes. * An RDD scope tracks the series of operations that created a given RDD. */ -private[spark] object RDDOperationScope { +private[spark] object RDDOperationScope extends Logging { private val jsonMapper = new ObjectMapper().registerModule(DefaultScalaModule) private val scopeCounter = new AtomicInteger(0) @@ -88,14 +87,24 @@ private[spark] object RDDOperationScope { /** * Execute the given body such that all RDDs created in this body will have the same scope. - * The name of the scope will be the name of the method that immediately encloses this one. + * The name of the scope will be the first method name in the stack trace that is not the + * same as this method's. * * Note: Return statements are NOT allowed in body. */ private[spark] def withScope[T]( sc: SparkContext, allowNesting: Boolean = false)(body: => T): T = { - val callerMethodName = Thread.currentThread.getStackTrace()(3).getMethodName + val ourMethodName = "withScope" + val callerMethodName = Thread.currentThread.getStackTrace() + .dropWhile(_.getMethodName != ourMethodName) + .find(_.getMethodName != ourMethodName) + .map(_.getMethodName) + .getOrElse { + // Log a warning just in case, but this should almost certainly never happen + logWarning("No valid method name for this RDD operation scope!") + "N/A" + } withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 3dfcf67f0eb6..4b5f15dd06b8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -104,13 +104,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag if (!convertKey && !convertValue) { self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (x._1, anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), x._2)).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 633aeba3bbae..f7cb1791d4ac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -125,7 +125,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index a96b6c3d2345..81f40ad33aa5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -123,7 +123,7 @@ private[spark] class ZippedPartitionsRDD3 } private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( + [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 69181edb9ad4..6ae47894598b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -17,8 +17,7 @@ package org.apache.spark.rpc -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.Future import scala.reflect.ClassTag import org.apache.spark.util.RpcUtils @@ -32,7 +31,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) private[this] val maxRetries = RpcUtils.numRetries(conf) private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) /** * return the address for the [[RpcEndpointRef]] @@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. */ - def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to @@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { // TODO: Consider removing multiple attempts var attempts = 0 var lastException: Exception = null @@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) attempts += 1 try { val future = ask[T](message, timeout) - val result = Await.result(future, timeout) + val result = timeout.awaitResult(future) if (result == null) { throw new SparkException("Actor returned null") } @@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) lastException = e logWarning(s"Error sending message [message = $message] in $attempts attempts", e) } - Thread.sleep(retryWaitMs) + + if (attempts < maxRetries) { + Thread.sleep(retryWaitMs) + } } throw new SparkException( s"Error sending message [message = $message]", lastException) } + } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 12b6b28d4d7e..1709bdf560b6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -18,8 +18,10 @@ package org.apache.spark.rpc import java.net.URI +import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Future} +import scala.concurrent.{Awaitable, Await, Future} +import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.{SecurityManager, SparkConf} @@ -66,7 +68,7 @@ private[spark] object RpcEnv { */ private[spark] abstract class RpcEnv(conf: SparkConf) { - private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) + private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -94,7 +96,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. */ def setupEndpointRefByURI(uri: String): RpcEndpointRef = { - Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) + defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } /** @@ -158,6 +160,8 @@ private[spark] case class RpcAddress(host: String, port: Int) { val hostPort: String = host + ":" + port override val toString: String = hostPort + + def toSparkURL: String = "spark://" + hostPort } @@ -182,3 +186,107 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index ba0d468f111e..f2d87f68341a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future -import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -29,9 +28,11 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import com.google.common.util.concurrent.MoreExecutors + import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} /** * A RpcEnv implementation based on Akka. @@ -178,10 +179,10 @@ private[spark] class AkkaRpcEnv private[akka] ( }) } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - _sender ! AkkaFailure(e) - } else { + _sender ! AkkaFailure(e) + if (!needReply) { + // If the sender does not require a reply, it may not handle the exception. So we rethrow + // "e" to make sure it will be processed. throw e } } @@ -212,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] ( override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). + map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). + // this is just in case there is a timeout from creating the future in resolveOne, we want the + // exception to indicate the conf that determines the timeout + recover(defaultLookupTimeout.addMessageIfTimeout) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -293,9 +297,9 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { - import scala.concurrent.ExecutionContext.Implicits.global - actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { + // The function will run in the calling thread, so it should be short and never block. case msg @ AkkaMessage(message, reply) => if (reply) { logError(s"Receive $msg but the sender cannot reply") @@ -305,7 +309,8 @@ private[akka] class AkkaRpcEndpointRef( } case AkkaFailure(e) => Future.failed(e) - }.mapTo[T] + }(ThreadUtils.sameThread).mapTo[T]. + recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } override def toString: String = s"${getClass.getSimpleName}($actorRef)" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5d812918a13d..6841fa835747 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ @@ -81,6 +82,8 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -137,6 +140,22 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + // Flag to control if reduce tasks are assigned preferred locations + private val shuffleLocalityEnabled = + sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + // Number of map, reduce tasks above which we do not assign preferred locations + // based on map output sizes. We limit the size of jobs for which assign preferred locations + // as computing the top locations by size becomes expensive. + private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. + // Making this larger will focus on fewer locations where most data can be read locally, but + // may lead to more delay in scheduling if those locations are busy. + private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 + // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -170,7 +189,7 @@ class DAGScheduler( blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( - BlockManagerHeartbeat(blockManagerId), 600 seconds) + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } // Called by TaskScheduler when an executor fails. @@ -193,9 +212,15 @@ class DAGScheduler( def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms => - bms.map(bm => TaskLocation(bm.host, bm.executorId)) + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + Seq.fill(rdd.partitions.size)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } } cacheLocs(rdd.id) = locs } @@ -208,19 +233,17 @@ class DAGScheduler( /** * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) */ private def getShuffleMapStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, jobId) + registerShuffleDependencies(shuffleDep, firstJobId) // Then register current shuffleDep - val stage = newOrUsedShuffleStage(shuffleDep, jobId) + val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage @@ -230,15 +253,15 @@ class DAGScheduler( /** * Helper function to eliminate some code re-use when creating new stages. */ - private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { - val parentStages = getParentStages(rdd, jobId) + private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, firstJobId) val id = nextStageId.getAndIncrement() (parentStages, id) } /** * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in - * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * newOrUsedShuffleStage. The stage will be associated with the provided firstJobId. * Production of shuffle map stages should always use newOrUsedShuffleStage, not * newShuffleMapStage directly. */ @@ -246,21 +269,19 @@ class DAGScheduler( rdd: RDD[_], numTasks: Int, shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, + firstJobId: Int, callSite: CallSite): ShuffleMapStage = { - val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId) val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, - jobId, callSite, shuffleDep) + firstJobId, callSite, shuffleDep) stageIdToStage(id) = stage - updateJobIdStageIdMaps(jobId, stage) + updateJobIdStageIdMaps(firstJobId, stage) stage } /** - * Create a ResultStage -- either directly for use as a result stage, or as part of the - * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated - * with the provided jobId. + * Create a ResultStage associated with the provided jobId. */ private def newResultStage( rdd: RDD[_], @@ -277,16 +298,16 @@ class DAGScheduler( /** * Create a shuffle map Stage for the given RDD. The stage will also be associated with the - * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd val numTasks = rdd.partitions.size - val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) @@ -304,10 +325,10 @@ class DAGScheduler( } /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. + * Get or create the list of parent stages for a given RDD. The new Stages will be created with + * the provided firstJobId. */ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError @@ -321,7 +342,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - parents += getShuffleMapStage(shufDep, jobId) + parents += getShuffleMapStage(shufDep, firstJobId) case _ => waitingForVisit.push(dep.rdd) } @@ -336,11 +357,11 @@ class DAGScheduler( } /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, jobId) + val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } @@ -386,11 +407,12 @@ class DAGScheduler( def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { missing += mapStage } @@ -577,7 +599,7 @@ class DAGScheduler( private[scheduler] def doCancelAllJobs() { // Cancel all running jobs. - runningStages.map(_.jobId).foreach(handleJobCancellation(_, + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, reason = "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... @@ -603,7 +625,7 @@ class DAGScheduler( clearCacheLocs() val failedStagesCopy = failedStages.toArray failedStages.clear() - for (stage <- failedStagesCopy.sortBy(_.jobId)) { + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -623,7 +645,7 @@ class DAGScheduler( logTrace("failed: " + failedStages) val waitingStagesCopy = waitingStages.toArray waitingStages.clear() - for (stage <- waitingStagesCopy.sortBy(_.jobId)) { + for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -843,7 +865,7 @@ class DAGScheduler( } } - val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -886,22 +908,29 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = stage match { - case stage: ShuffleMapStage => - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } + val tasks: Seq[Task[_]] = try { + stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } - case stage: ResultStage => - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return } if (tasks.size > 0) { @@ -909,7 +938,7 @@ class DAGScheduler( stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1323,7 +1352,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back @@ -1364,10 +1393,10 @@ class DAGScheduler( private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 - if (!visited.add((rdd,partition))) { + if (!visited.add((rdd, partition))) { // Nil has already been returned for previously visited partitions. return Nil } @@ -1381,17 +1410,32 @@ class DAGScheduler( if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. + rdd.dependencies.foreach { case n: NarrowDependency[_] => + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } + case s: ShuffleDependency[_, _, _] => + // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION + // of data as preferred locations + if (shuffleLocalityEnabled && + rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && + s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => } Nil @@ -1404,17 +1448,29 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread at the end of the constructor + // Start the event thread and register the metrics source at the end of the constructor + env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + /** * The main event loop of the DAG scheduler. */ - override def onReceive(event: DAGSchedulerEvent): Unit = event match { + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 12668b6c0988..6b667d5d7645 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry, Timer} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) +private[scheduler] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "DAGScheduler" @@ -46,4 +45,8 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) + + /** Timer that tracks the time to process messages in the DAGScheduler's event loop */ + val messageProcessingTimer: Timer = + metricRegistry.timer(MetricRegistry.name("messageProcessingTime")) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 86f357abb872..c6d957b65f3f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. * @param sourceName Filename (or other source identifier) from whence @logData is being read - * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not */ def replay( @@ -62,7 +62,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { if (!maybeTruncated || lines.hasNext) { throw jpe } else { - logWarning(s"Got JsonParseException from log file $sourceName" + + logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index c0f3d5a13d62..bf81b9aca481 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -28,9 +28,9 @@ private[spark] class ResultStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { // The active job for this result stage. Will be empty if the job has already finished // (e.g., because the job was cancelled). diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 646820520ea1..8801a761afae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -49,4 +49,11 @@ private[spark] trait SchedulerBackend { */ def applicationAttemptId(): Option[String] = None + /** + * Get the URLs for the driver logs. These URLs are used to display the links in the UI + * Executors tab for the driver. + * @return Map containing the log names and their respective URLs + */ + def getDriverLogUrls: Option[Map[String, String]] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 5e62c8468f00..864941d468af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -56,7 +56,7 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var compare:Int = 0 + var compare: Int = 0 if (s1Needy && !s2Needy) { return true diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index d02210743484..66c75f325fcd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -30,10 +30,10 @@ private[spark] class ShuffleMapStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite, val shuffleDep: ShuffleDependency[_, _, _]) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { override def toString: String = "ShuffleMapStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 169d4fd3a94f..9620915f495a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -110,8 +110,13 @@ case class SparkListenerExecutorMetricsUpdate( extends SparkListenerEvent @DeveloperApi -case class SparkListenerApplicationStart(appName: String, appId: Option[String], - time: Long, sparkUser: String, appAttemptId: Option[String]) extends SparkListenerEvent +case class SparkListenerApplicationStart( + appName: String, + appId: Option[String], + time: Long, + sparkUser: String, + appAttemptId: Option[String], + driverLogs: Option[Map[String, String]] = None) extends SparkListenerEvent @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent @@ -265,7 +270,7 @@ class StatsReportListener extends SparkListener with Logging { private[spark] object StatsReportListener extends Logging { // For profiling, the extremes are more interesting - val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) val probabilities = percentiles.map(_ / 100.0) val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" @@ -299,7 +304,7 @@ private[spark] object StatsReportListener extends Logging { dOpt.foreach { d => showDistribution(heading, d, formatNumber)} } - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -313,7 +318,7 @@ private[spark] object StatsReportListener extends Logging { } def showBytesDistribution( - heading:String, + heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long], taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5d0ddb8377c3..c59d6e4f5bc0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.CallSite * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes * that each output partition is on. * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO + * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * @@ -51,7 +51,7 @@ private[spark] abstract class Stage( val rdd: RDD[_], val numTasks: Int, val parents: List[Stage], - val jobId: Int, + val firstJobId: Int, val callSite: CallSite) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 586d1e06204c..15101c64f050 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -125,7 +125,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 1f114a0207f7..8b2a742b9698 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -40,6 +40,9 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long var metrics: TaskMetrics) extends TaskResult[T] with Externalizable { + private var valueObjectDeserialized = false + private var valueObject: T = _ + def this() = this(null.asInstanceOf[ByteBuffer], null, null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { @@ -72,10 +75,26 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long } } metrics = in.readObject().asInstanceOf[TaskMetrics] + valueObjectDeserialized = false } + /** + * When `value()` is called at the first time, it needs to deserialize `valueObject` from + * `valueBytes`. It may cost dozens of seconds for a large instance. So when calling `value` at + * the first time, the caller should avoid to block other threads. + * + * After the first time, `value()` is trivial and just returns the deserialized `valueObject`. + */ def value(): T = { - val resultSer = SparkEnv.get.serializer.newInstance() - resultSer.deserialize(valueBytes) + if (valueObjectDeserialized) { + valueObject + } else { + // This should not run when holding a lock because it may cost dozens of seconds for a large + // value. + val resultSer = SparkEnv.get.serializer.newInstance() + valueObject = resultSer.deserialize(valueBytes) + valueObjectDeserialized = true + valueObject + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 391827c1d215..46a6f6537e2e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -54,6 +54,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { return } + // deserialize "value" without holding any lock so that it won't block other threads. + // We should call it here, so that when it's called again in + // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. + directResult.value() (directResult, serializedData.limit()) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index b4b8a630694b..ed3dde0fc305 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{TimerTask, Timer} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -64,6 +64,9 @@ private[spark] class TaskSchedulerImpl( // How often to check for speculative tasks val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") + private val speculationScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation") + // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") @@ -142,10 +145,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, - SPECULATION_INTERVAL_MS milliseconds) { - Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - }(sc.env.actorSystem.dispatcher) + speculationScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + checkSpeculatableTasks() + } + }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS) } } @@ -412,6 +416,7 @@ private[spark] class TaskSchedulerImpl( } override def stop() { + speculationScheduler.shutdown() if (backend != null) { backend.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 7dc325283d96..82455b0426a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} * * @param sched the TaskSchedulerImpl associated with the TaskSetManager * @param taskSet the TaskSet to manage scheduling for - * @param maxTaskFailures if any particular task fails more than this number of times, the entire + * @param maxTaskFailures if any particular task fails this number of times, the entire * task set will be aborted */ private[spark] class TaskSetManager( @@ -620,6 +620,12 @@ private[spark] class TaskSetManager( val index = info.index info.markSuccessful() removeRunningTask(tid) + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. sched.dagScheduler.taskEnded( tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { @@ -775,10 +781,10 @@ private[spark] class TaskSetManager( // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because dequeueTaskFromList will skip already-running tasks. for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, @@ -855,9 +861,9 @@ private[spark] class TaskSetManager( case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" case _ => null } - + if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait) } else { 0L } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 70364cea62a8..4be1eda2e929 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -75,7 +75,8 @@ private[spark] object CoarseGrainedClusterMessages { case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) + case class AddWebUIFilter( + filterName: String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage // Messages exchanged between the driver and the cluster manager for executor allocation diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f107148f3b8c..7c7f70d8a193 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + // If this DriverEndpoint is changed to support multiple threads, + // then this may need to be changed so that we don't share the serializer + // instance across threads + private val ser = SparkEnv.get.closureSerializer.newInstance() + override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[RpcAddress, String] @@ -79,7 +84,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") - + reviveThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReviveOffers)) @@ -98,7 +103,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case None => // Ignoring the update since we don't know about the executor. logWarning(s"Ignored task status update ($taskId state $state) " + - "from unknown executor $sender with ID $executorId") + s"from unknown executor with ID $executorId") } } @@ -163,7 +168,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on all executors - def makeOffers() { + private def makeOffers() { launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toSeq)) @@ -175,16 +180,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on just one executor - def makeOffers(executorId: String) { + private def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) launchTasks(scheduler.resourceOffers( Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) } // Launch tasks returned by a set of resource offers - def launchTasks(tasks: Seq[Seq[TaskDescription]]) { + private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - val ser = SparkEnv.get.closureSerializer.newInstance() val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ccf1dc5af612..687ae9620460 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend( val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() waitForRegistration() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 2a3a5d925d06..bc67abb5df44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend( private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. @@ -149,7 +149,7 @@ private[spark] abstract class YarnSchedulerBackend( } } - override def onStop(): Unit ={ + override def onStop(): Unit = { askAmThreadPool.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index dc59545b4331..b68f8c7685eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,17 +18,18 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{Collections, List => JList} +import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.spark.util.Utils /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -51,7 +52,7 @@ private[spark] class CoarseMesosSchedulerBackend( val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] @@ -65,6 +66,10 @@ private[spark] class CoarseMesosSchedulerBackend( val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) + // Offer constraints + private val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + var nextMesosTaskId = 0 @volatile var appId: String = _ @@ -115,11 +120,9 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(sc.env.actorSystem), + val driverUrl = sc.env.rpcEnv.uriOf( SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") @@ -171,13 +174,16 @@ private[spark] class CoarseMesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { + val offerAttributes = toAttributeMap(offer.getAttributesList) + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val slaveId = offer.getSlaveId.toString val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && - mem >= MemoryUtils.calculateTotalMemory(sc) && + val id = offer.getId.getValue + if (meetsConstraints && + totalCoresAcquired < maxCores && + mem >= calculateTotalMemory(sc) && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { @@ -194,33 +200,25 @@ private[spark] class CoarseMesosSchedulerBackend( .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", - MemoryUtils.calculateTotalMemory(sc))) + .addResources(createResource("mem", calculateTotalMemory(sc))) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder) } - d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) + // accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.launchTasks(List(offer.getId), List(task.build()), filters) } else { - // Filter it out - d.launchTasks( - Collections.singleton(offer.getId), Collections.emptyList[MesosTaskInfo](), filters) + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } } } } - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue.toInt diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1067a7f1caf4..d3a20f822176 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,6 +29,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index db0a080b3b0c..d72e2af456e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,14 +23,14 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a @@ -59,6 +59,10 @@ private[spark] class MesosSchedulerBackend( private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) + // Offer constraints + private[this] val slaveOfferConstraints = + parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + @volatile var appId: String = _ override def start() { @@ -71,8 +75,8 @@ private[spark] class MesosSchedulerBackend( val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => environment.addVariables( @@ -115,14 +119,14 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(mesosExecutorCores).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) .setScalar( Value.Scalar.newBuilder() - .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) + .setValue(calculateTotalMemory(sc)).build()) .build() val executorInfo = MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -146,7 +150,7 @@ private[spark] class MesosSchedulerBackend( private def createExecArg(): Array[Byte] = { if (execArgs == null) { val props = new HashMap[String, String] - for ((key,value) <- sc.conf.getAll) { + for ((key, value) <- sc.conf.getAll) { props(key) = value } // Serialize the map as an array of (String, String) pairs @@ -191,13 +195,31 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) + val offerAttributes = toAttributeMap(o.getAttributesList) + + // check if all constraints are satisfield + // 1. Attribute constraints + // 2. Memory requirements + // 3. CPU requirements - need at least 1 for executor, 1 for task + val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) + val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) + + val meetsRequirements = + (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) + + // add some debug messaging + val debugstr = if (meetsRequirements) "Accepting" else "Declining" + val id = o.getId.getValue + logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + + meetsRequirements } + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) + val workerOffers = usableOffers.map { o => val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt @@ -223,15 +245,15 @@ private[spark] class MesosSchedulerBackend( val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) acceptedOffers .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slaveIdsWithExecutors += slaveId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(createMesosTask(taskDesc, slaveId)) - } + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } + } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? @@ -251,8 +273,6 @@ private[spark] class MesosSchedulerBackend( d.declineOffer(o.getId) } - // Decline offers we ruled out immediately - unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 928c5cfed417..e79c543a9de2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -37,14 +37,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .newBuilder() .setMode(Volume.Mode.RW) spec match { - case Array(container_path) => + case Array(container_path) => Some(vol.setContainerPath(container_path)) case Array(container_path, "rw") => Some(vol.setContainerPath(container_path)) case Array(container_path, "ro") => Some(vol.setContainerPath(container_path) .setMode(Volume.Mode.RO)) - case Array(host_path, container_path) => + case Array(host_path, container_path) => Some(vol.setContainerPath(container_path) .setHostPath(host_path)) case Array(host_path, container_path, "rw") => @@ -108,7 +108,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { image: String, volumes: Option[List[Volume]] = None, network: Option[ContainerInfo.DockerInfo.Network] = None, - portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None):Unit = { + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = { val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index d11228f3d016..d8a8c848bb4d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,14 +17,17 @@ package org.apache.spark.scheduler.cluster.mesos -import java.util.List +import java.util.{List => JList} import java.util.concurrent.CountDownLatch import scala.collection.JavaConversions._ +import scala.util.control.NonFatal -import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} -import org.apache.mesos.{MesosSchedulerDriver, Scheduler} -import org.apache.spark.Logging +import com.google.common.base.Splitter +import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler} +import org.apache.mesos.Protos._ +import org.apache.mesos.protobuf.GeneratedMessage +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.util.Utils /** @@ -86,10 +89,150 @@ private[mesos] trait MesosSchedulerUtils extends Logging { /** * Get the amount of resources for the specified type from the resource list */ - protected def getResource(res: List[Resource], name: String): Double = { + protected def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } 0.0 } + + /** Helper method to get the key,value-set pair for a Mesos Attribute protobuf */ + protected def getAttribute(attr: Attribute): (String, Set[String]) = { + (attr.getName, attr.getText.getValue.split(',').toSet) + } + + + /** Build a Mesos resource protobuf object */ + protected def createResource(resourceName: String, quantity: Double): Protos.Resource = { + Resource.newBuilder() + .setName(resourceName) + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) + .build() + } + + /** + * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * The attribute values are the mesos attribute types and they are + * @param offerAttributes + * @return + */ + protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { + offerAttributes.map(attr => { + val attrValue = attr.getType match { + case Value.Type.SCALAR => attr.getScalar + case Value.Type.RANGES => attr.getRanges + case Value.Type.SET => attr.getSet + case Value.Type.TEXT => attr.getText + } + (attr.getName, attrValue) + }).toMap + } + + + /** + * Match the requirements (if any) to the offer attributes. + * if attribute requirements are not specified - return true + * else if attribute is defined and no values are given, simple attribute presence is performed + * else if attribute name and value is specified, subset match is performed on slave attributes + */ + def matchesAttributeRequirements( + slaveOfferConstraints: Map[String, Set[String]], + offerAttributes: Map[String, GeneratedMessage]): Boolean = { + slaveOfferConstraints.forall { + // offer has the required attribute and subsumes the required values for that attribute + case (name, requiredValues) => + offerAttributes.get(name) match { + case None => false + case Some(_) if requiredValues.isEmpty => true // empty value matches presence + case Some(scalarValue: Value.Scalar) => + // check if provided values is less than equal to the offered values + requiredValues.map(_.toDouble).exists(_ <= scalarValue.getValue) + case Some(rangeValue: Value.Range) => + val offerRange = rangeValue.getBegin to rangeValue.getEnd + // Check if there is some required value that is between the ranges specified + // Note: We only support the ability to specify discrete values, in the future + // we may expand it to subsume ranges specified with a XX..YY value or something + // similar to that. + requiredValues.map(_.toLong).exists(offerRange.contains(_)) + case Some(offeredValue: Value.Set) => + // check if the specified required values is a subset of offered set + requiredValues.subsetOf(offeredValue.getItemList.toSet) + case Some(textValue: Value.Text) => + // check if the specified value is equal, if multiple values are specified + // we succeed if any of them match. + requiredValues.contains(textValue.getValue) + } + } + } + + /** + * Parses the attributes constraints provided to spark and build a matching data struct: + * Map[, Set[values-to-match]] + * The constraints are specified as ';' separated key-value pairs where keys and values + * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for + * multiple values (comma separated). For example: + * {{{ + * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * // would result in + * + * Map( + * "tachyon" -> Set("true"), + * "zone": -> Set("us-east-1a", "us-east-1b") + * ) + * }}} + * + * Mesos documentation: http://mesos.apache.org/documentation/attributes-resources/ + * https://github.com/apache/mesos/blob/master/src/common/values.cpp + * https://github.com/apache/mesos/blob/master/src/common/attributes.cpp + * + * @param constraintsVal constaints string consisting of ';' separated key-value pairs (separated + * by ':') + * @return Map of constraints to match resources offers. + */ + def parseConstraintString(constraintsVal: String): Map[String, Set[String]] = { + /* + Based on mesos docs: + attributes : attribute ( ";" attribute )* + attribute : labelString ":" ( labelString | "," )+ + labelString : [a-zA-Z0-9_/.-] + */ + val splitter = Splitter.on(';').trimResults().withKeyValueSeparator(':') + // kv splitter + if (constraintsVal.isEmpty) { + Map() + } else { + try { + Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { + case (k, v) => + if (v == null || v.isEmpty) { + (k, Set[String]()) + } else { + (k, v.split(',').toSet) + } + } + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) + } + } + } + + // These defaults copied from YARN + private val MEMORY_OVERHEAD_FRACTION = 0.10 + private val MEMORY_OVERHEAD_MINIMUM = 384 + + /** + * Return the amount of memory to allocate to each executor, taking into account + * container overheads. + * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value + * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * (whichever is larger) + */ + def calculateTotalMemory(sc: SparkContext): Int = { + sc.conf.getInt("spark.mesos.executor.memoryOverhead", + math.max(MEMORY_OVERHEAD_FRACTION * sc.executorMemory, MEMORY_OVERHEAD_MINIMUM).toInt) + + sc.executorMemory + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index e64d06c4d3cf..3078a1b10be8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,14 +18,12 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer -import java.util.concurrent.TimeUnit import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.{ThreadUtils, Utils} private case class ReviveOffers() @@ -47,9 +45,6 @@ private[spark] class LocalEndpoint( private val totalCores: Int) extends ThreadSafeRpcEndpoint with Logging { - private val reviveThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("local-revive-thread") - private var freeCores = totalCores private val localExecutorId = SparkContext.DRIVER_IDENTIFIER @@ -79,27 +74,13 @@ private[spark] class LocalEndpoint( context.reply(true) } - def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) - val tasks = scheduler.resourceOffers(offers).flatten - for (task <- tasks) { + for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, task.name, task.serializedTask) } - if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { - // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - reviveThread.schedule(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ReviveOffers)) - } - }, 1000, TimeUnit.MILLISECONDS) - } - } - - override def onStop(): Unit = { - reviveThread.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 64ba27f34d2f..ed35cffe968f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,8 +17,9 @@ package org.apache.spark.serializer -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer +import javax.annotation.Nullable import scala.reflect.ClassTag @@ -35,7 +36,7 @@ import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -51,7 +52,7 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") - + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") @@ -93,8 +94,10 @@ class KryoSerializer(conf: SparkConf) // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) - // Allow sending SerializableWritable + // Allow sending classes with custom Java serializers kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) + kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) @@ -136,21 +139,45 @@ class KryoSerializer(conf: SparkConf) } private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) +class KryoSerializationStream( + serInstance: KryoSerializerInstance, + outStream: OutputStream) extends SerializationStream { + + private[this] var output: KryoOutput = new KryoOutput(outStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - override def flush() { output.flush() } - override def close() { output.close() } + override def flush() { + if (output == null) { + throw new IOException("Stream is closed") + } + output.flush() + } + + override def close() { + if (output != null) { + try { + output.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + output = null + } + } + } } private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - private val input = new KryoInput(inStream) +class KryoDeserializationStream( + serInstance: KryoSerializerInstance, + inStream: InputStream) extends DeserializationStream { + + private[this] var input: KryoInput = new KryoInput(inStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def readObject[T: ClassTag](): T = { try { @@ -163,50 +190,105 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } override def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() + if (input != null) { + try { + // Kryo's Input automatically closes the input stream it is using. + input.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + input = null + } + } } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - private val kryo = ks.newKryo() - // Make these lazy vals to avoid creating a buffer unless we use them + /** + * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do + * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching + * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are + * not synchronized. + */ + @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + + /** + * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; + * otherwise, it allocates a new instance. + */ + private[serializer] def borrowKryo(): Kryo = { + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have been modified + // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } + } + + /** + * Release a borrowed [[Kryo]] instance. If this serializer instance already has a cached Kryo + * instance, then the given Kryo instance is discarded; otherwise, the Kryo is stored for later + * re-use. + */ + private[serializer] def releaseKryo(kryo: Kryo): Unit = { + if (cachedKryo == null) { + cachedKryo = kryo + } + } + + // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() private lazy val input = new KryoInput() override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() + val kryo = borrowKryo() try { kryo.writeClassAndObject(output, t) } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + "increase spark.kryoserializer.buffer.max value.") + } finally { + releaseKryo(kryo) } ByteBuffer.wrap(output.toBytes) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] + val kryo = borrowKryo() + try { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + releaseKryo(kryo) + } } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + val kryo = borrowKryo() val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj + try { + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + kryo.setClassLoader(oldClassLoader) + releaseKryo(kryo) + } } override def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(kryo, s) + new KryoSerializationStream(this, s) } override def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) + new KryoDeserializationStream(this, s) } /** @@ -216,7 +298,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ def getAutoReset(): Boolean = { val field = classOf[Kryo].getDeclaredField("autoReset") field.setAccessible(true) - field.get(kryo).asInstanceOf[Boolean] + val kryo = borrowKryo() + try { + field.get(kryo).asInstanceOf[Boolean] + } finally { + releaseKryo(kryo) + } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index 5abfa467c0ec..cc2f0506817d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.io._ import java.lang.reflect.{Field, Method} import java.security.AccessController @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark.Logging -private[serializer] object SerializationDebugger extends Logging { +private[spark] object SerializationDebugger extends Logging { /** * Improve the given NotSerializableException with the serialization path leading from the given @@ -62,7 +62,7 @@ private[serializer] object SerializationDebugger extends Logging { * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ - def find(obj: Any): List[String] = { + private[serializer] def find(obj: Any): List[String] = { new SerializationDebugger().visit(obj, List.empty) } @@ -125,6 +125,12 @@ private[serializer] object SerializationDebugger extends Logging { return List.empty } + /** + * Visit an externalizable object. + * Since writeExternal() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutput that collects all the relevant objects for further testing. + */ private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = { val fieldList = new ListObjectOutput @@ -145,17 +151,50 @@ private[serializer] object SerializationDebugger extends Logging { // An object contains multiple slots in serialization. // Get the slots and visit fields in all of them. val (finalObj, desc) = findObjectAndDescriptor(o) + + // If the object has been replaced using writeReplace(), + // then call visit() on it again to test its type again. + if (!finalObj.eq(o)) { + return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack) + } + + // Every class is associated with one or more "slots", each slot refers to the parent + // classes of this class. These slots are used by the ObjectOutputStream + // serialization code to recursively serialize the fields of an object and + // its parent classes. For example, if there are the following classes. + // + // class ParentClass(parentField: Int) + // class ChildClass(childField: Int) extends ParentClass(1) + // + // Then serializing the an object Obj of type ChildClass requires first serializing the fields + // of ParentClass (that is, parentField), and then serializing the fields of ChildClass + // (that is, childField). Correspondingly, there will be two slots related to this object: + // + // 1. ParentClass slot, which will be used to serialize parentField of Obj + // 2. ChildClass slot, which will be used to serialize childField fields of Obj + // + // The following code uses the description of each slot to find the fields in the + // corresponding object to visit. + // val slotDescs = desc.getSlotDescs var i = 0 while (i < slotDescs.length) { val slotDesc = slotDescs(i) if (slotDesc.hasWriteObjectMethod) { - // TODO: Handle classes that specify writeObject method. + // If the class type corresponding to current slot has writeObject() defined, + // then its not obvious which fields of the class will be serialized as the writeObject() + // can choose arbitrary fields for serialization. This case is handled separately. + val elem = s"writeObject data (class: ${slotDesc.getName})" + val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack) + if (childStack.nonEmpty) { + return childStack + } } else { + // Visit all the fields objects of the class corresponding to the current slot. val fields: Array[ObjectStreamField] = slotDesc.getFields val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) val numPrims = fields.length - objFieldValues.length - desc.getObjFieldValues(finalObj, objFieldValues) + slotDesc.getObjFieldValues(finalObj, objFieldValues) var j = 0 while (j < objFieldValues.length) { @@ -169,18 +208,54 @@ private[serializer] object SerializationDebugger extends Logging { } j += 1 } - } i += 1 } return List.empty } + + /** + * Visit a serializable object which has the writeObject() defined. + * Since writeObject() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutputStream that collects all the relevant fields for further testing. + * This is similar to how externalizable objects are visited. + */ + private def visitSerializableWithWriteObjectMethod( + o: Object, stack: List[String]): List[String] = { + val innerObjectsCatcher = new ListObjectOutputStream + var notSerializableFound = false + try { + innerObjectsCatcher.writeObject(o) + } catch { + case io: IOException => + notSerializableFound = true + } + + // If something was not serializable, then visit the captured objects. + // Otherwise, all the captured objects are safely serializable, so no need to visit them. + // As an optimization, just added them to the visited list. + if (notSerializableFound) { + val innerObjects = innerObjectsCatcher.outputArray + var k = 0 + while (k < innerObjects.length) { + val childStack = visit(innerObjects(k), stack) + if (childStack.nonEmpty) { + return childStack + } + k += 1 + } + } else { + visited ++= innerObjectsCatcher.outputArray + } + return List.empty + } } /** * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles * writeReplace in Serializable. It starts with the object itself, and keeps calling the - * writeReplace method until there is no more + * writeReplace method until there is no more. */ @tailrec private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { @@ -220,6 +295,31 @@ private[serializer] object SerializationDebugger extends Logging { override def writeByte(i: Int): Unit = {} } + /** An output stream that emulates /dev/null */ + private class NullOutputStream extends OutputStream { + override def write(b: Int) { } + } + + /** + * A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns + * them through `outputArray`. This works by using the [[ObjectOutputStream]]'s `replaceObject()` + * method which gets called on every object, only if replacing is enabled. So this subclass + * of [[ObjectOutputStream]] enabled replacing, and uses replaceObject to get the objects that + * are being serializabled. The serialized bytes are ignored by sending them to a + * [[NullOutputStream]], which acts like a /dev/null. + */ + private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) { + private val output = new mutable.ArrayBuffer[Any] + this.enableReplaceObject(true) + + def outputArray: Array[Any] = output.toArray + + override def replaceObject(obj: Object): Object = { + output += obj + obj + } + } + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { def getSlotDescs: Array[ObjectStreamClass] = { diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 6078c9d433eb..bd2704dc8187 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag @@ -114,8 +115,12 @@ object Serializer { /** * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. + * + * It is legal to create multiple serialization / deserialization streams from the same + * SerializerInstance as long as those streams are all used within the same thread. */ @DeveloperApi +@NotThreadSafe abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer @@ -177,6 +182,7 @@ abstract class DeserializationStream { } catch { case eof: EOFException => finished = true + null } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 80374adc4429..9d8e7e9f03ae 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -17,29 +17,29 @@ package org.apache.spark.shuffle.hash -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} +import java.io.InputStream + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.{Failure, Success} import org.apache.spark._ -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( + def fetchBlockStreams( shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) - : Iterator[T] = + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + blocksByAddress, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler + blockFetcherItr.map { blockPair => val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] + case Success(inputStream) => { + (blockId, inputStream) } case Failure(e) => { blockId match { @@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b..d5c9880659dd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,16 +17,20 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, @@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index c9dd6bfc4c21..5865e7640c1c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: ExternalSorter[K, V, _] = null + private var sorter: SortShuffleFileWriter[K, V] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { - if (dep.mapSideCombine) { + sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - sorter = new ExternalSorter[K, V, C]( + new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.insertAll(records) + } else if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, + writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) - sorter.insertAll(records) + new ExternalSorter[K, V, V]( + aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } + sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C]( } } +private[spark] object SortShuffleWriter { + def shouldBypassMergeSort( + conf: SparkConf, + numPartitions: Int, + aggregator: Option[Aggregator[_, _, _]], + keyOrdering: Option[Ordering[_]]): Boolean = { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index f2bfef376d3c..df7bbd64247d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -56,9 +56,6 @@ private[spark] object UnsafeShuffleManager extends Logging { } else if (dependency.aggregator.isDefined) { log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") false - } else if (dependency.keyOrdering.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") - false } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 50608588f09a..390c136df79b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -169,7 +169,7 @@ private[v1] object AllStagesResource { val outputMetrics: Option[OutputMetricDistributions] = new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw:InternalTaskMetrics): Option[InternalOutputMetrics] = { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = { raw.outputMetrics } def build: OutputMetricDistributions = new OutputMetricDistributions( @@ -284,7 +284,7 @@ private[v1] object AllStagesResource { * the options (returning None if the metrics are all empty), and extract the quantiles for each * metric. After creating an instance, call metricOption to get the result type. */ -private[v1] abstract class MetricHelper[I,O]( +private[v1] abstract class MetricHelper[I, O]( rawMetrics: Seq[InternalTaskMetrics], quantiles: Array[Double]) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala similarity index 86% rename from core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala rename to core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index c3ec45f54681..50b6ba67e993 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.status.api.v1 +import java.util.zip.ZipOutputStream import javax.servlet.ServletContext import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -39,7 +40,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. */ @Path("/v1") -private[v1] class JsonRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -101,7 +102,7 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { @Path("applications/{appId}/stages") - def getStages(@PathParam("appId") appId: String): AllStagesResource= { + def getStages(@PathParam("appId") appId: String): AllStagesResource = { uiRoot.withSparkUI(appId, None) { ui => new AllStagesResource(ui) } @@ -110,14 +111,14 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { @Path("applications/{appId}/{attemptId}/stages") def getStages( @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllStagesResource= { + @PathParam("attemptId") attemptId: String): AllStagesResource = { uiRoot.withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") - def getStage(@PathParam("appId") appId: String): OneStageResource= { + def getStage(@PathParam("appId") appId: String): OneStageResource = { uiRoot.withSparkUI(appId, None) { ui => new OneStageResource(ui) } @@ -164,14 +165,26 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { } } + @Path("applications/{appId}/logs") + def getEventLogs( + @PathParam("appId") appId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, None) + } + + @Path("applications/{appId}/{attemptId}/logs") + def getEventLogs( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } -private[spark] object JsonRootResource { +private[spark] object ApiRootResource { - def getJsonServlet(uiRoot: UIRoot): ServletContextHandler = { + def getServletHandler(uiRoot: UIRoot): ServletContextHandler = { val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) - jerseyContext.setContextPath("/json") - val holder:ServletHolder = new ServletHolder(classOf[ServletContainer]) + jerseyContext.setContextPath("/api") + val holder: ServletHolder = new ServletHolder(classOf[ServletContainer]) holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass", "com.sun.jersey.api.core.PackagesResourceConfig") holder.setInitParameter("com.sun.jersey.config.property.packages", @@ -193,6 +206,17 @@ private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + /** + * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is + * [[None]], event logs for all attempts of this application will be written out. + */ + def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { + Response.serverError() + .entity("Event logs are only available through the history server.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + /** * Get the spark UI with the given appID, and apply a function * to it. If there is no such app, throw an appropriate exception diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala new file mode 100644 index 000000000000..22e21f0c62a2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/EventLogDownloadResource.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status.api.v1 + +import java.io.OutputStream +import java.util.zip.ZipOutputStream +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.{MediaType, Response, StreamingOutput} + +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil + +@Produces(Array(MediaType.APPLICATION_OCTET_STREAM)) +private[v1] class EventLogDownloadResource( + val uIRoot: UIRoot, + val appId: String, + val attemptId: Option[String]) extends Logging { + val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf) + + @GET + def getEventLogs(): Response = { + try { + val fileName = { + attemptId match { + case Some(id) => s"eventLogs-$appId-$id.zip" + case None => s"eventLogs-$appId.zip" + } + } + + val stream = new StreamingOutput { + override def write(output: OutputStream): Unit = { + val zipStream = new ZipOutputStream(output) + try { + uIRoot.writeEventLogs(appId, attemptId, zipStream) + } finally { + zipStream.close() + } + + } + } + + Response.ok(stream) + .header("Content-Disposition", s"attachment; filename=$fileName") + .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) + .build() + } catch { + case NonFatal(e) => + Response.serverError() + .entity(s"Event logs are not available for app: $appId.") + .status(Response.Status.SERVICE_UNAVAILABLE) + .build() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index 07b224fac478..dfdc09c6caf3 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.SparkUI private[v1] class OneRDDResource(ui: SparkUI) { @GET - def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( throw new NotFoundException(s"no rdd found w/ id $rddId") ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index fd24aea63a8a..f9812f06cf52 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -83,7 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { withStageAttempt(stageId, stageAttemptId) { stage => val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) + tasks.slice(offset, offset + length) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index cee29786c301..0c71cd238222 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -16,40 +16,33 @@ */ package org.apache.spark.status.api.v1 -import java.text.SimpleDateFormat +import java.text.{ParseException, SimpleDateFormat} import java.util.TimeZone import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status -import scala.util.Try - private[v1] class SimpleDateParam(val originalValue: String) { - val timestamp: Long = { - SimpleDateParam.formats.collectFirst { - case fmt if Try(fmt.parse(originalValue)).isSuccess => - fmt.parse(originalValue).getTime() - }.getOrElse( - throw new WebApplicationException( - Response - .status(Status.BAD_REQUEST) - .entity("Couldn't parse date: " + originalValue) - .build() - ) - ) - } -} -private[v1] object SimpleDateParam { - - val formats: Seq[SimpleDateFormat] = { - - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") - gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) - - Seq( - new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz"), - gmtDay - ) + val timestamp: Long = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + try { + format.parse(originalValue).getTime() + } catch { + case _: ParseException => + val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) + try { + gmtDay.parse(originalValue).getTime() + } catch { + case _: ParseException => + throw new WebApplicationException( + Response + .status(Status.BAD_REQUEST) + .entity("Couldn't parse date: " + originalValue) + .build() + ) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index ef3c8570d818..2bec64f2ef02 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -134,7 +134,7 @@ class StageData private[spark]( val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], - val executorSummary:Option[Map[String,ExecutorStageSummary]]) + val executorSummary: Option[Map[String, ExecutorStageSummary]]) class TaskData private[spark]( val taskId: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index cc794e5c90ff..1beafa177144 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,12 +17,11 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} +import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ import scala.util.Random @@ -77,12 +76,17 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) + // Actual storage of where blocks are kept private var externalBlockStoreInitialized = false private[spark] val memoryStore = new MemoryStore(this, maxMemory) private[spark] val diskStore = new DiskStore(this, diskBlockManager) - private[spark] lazy val externalBlockStore: ExternalBlockStore = + private[spark] lazy val externalBlockStore: ExternalBlockStore = { + externalBlockStoreInitialized = true new ExternalBlockStore(this, executorId) + } private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -266,11 +270,13 @@ private[spark] class BlockManager( asyncReregisterLock.synchronized { if (asyncReregisterTask == null) { asyncReregisterTask = Future[Unit] { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool reregister() asyncReregisterLock.synchronized { asyncReregisterTask = null } - } + }(futureExecutionContext) } } } @@ -485,16 +491,17 @@ private[spark] class BlockManager( if (level.useOffHeap) { logDebug(s"Getting block $blockId from ExternalBlockStore") if (externalBlockStore.contains(blockId)) { - externalBlockStore.getBytes(blockId) match { - case Some(bytes) => - if (!asBlockResult) { - return Some(bytes) - } else { - return Some(new BlockResult( - dataDeserialize(blockId, bytes), DataReadMethod.Memory, info.size)) - } + val result = if (asBlockResult) { + externalBlockStore.getValues(blockId) + .map(new BlockResult(_, DataReadMethod.Memory, info.size)) + } else { + externalBlockStore.getBytes(blockId) + } + result match { + case Some(values) => + return result case None => - logDebug(s"Block $blockId not found in externalBlockStore") + logDebug(s"Block $blockId not found in ExternalBlockStore") } } } @@ -744,7 +751,11 @@ private[spark] class BlockManager( case b: ByteBufferValues if putLevel.replication > 1 => // Duplicate doesn't copy the bytes, but just creates a wrapper val bufferView = b.buffer.duplicate() - Future { replicate(blockId, bufferView, putLevel) } + Future { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool + replicate(blockId, bufferView, putLevel) + }(futureExecutionContext) case _ => null } @@ -1198,8 +1209,19 @@ private[spark] class BlockManager( bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator + dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + } + + /** + * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserializeStream( + blockId: BlockId, + inputStream: InputStream, + serializer: Serializer = defaultSerializer): Iterator[Any] = { + val stream = new BufferedInputStream(inputStream) + serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator } def stop(): Unit = { @@ -1218,6 +1240,7 @@ private[spark] class BlockManager( } metadataCleaner.cancel() broadcastCleaner.cancel() + futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index a85e1c763297..f70f701494db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,13 +17,14 @@ package org.apache.spark.storage +import scala.collection.Iterable +import scala.collection.generic.CanBuildFrom import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.RpcUtils +import org.apache.spark.util.{ThreadUtils, RpcUtils} private[spark] class BlockManagerMaster( @@ -32,7 +33,7 @@ class BlockManagerMaster( isDriver: Boolean) extends Logging { - val timeout = RpcUtils.askTimeout(conf) + val timeout = RpcUtils.askRpcTimeout(conf) /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { @@ -102,10 +103,10 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") - } + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -114,10 +115,10 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") - } + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -128,10 +129,10 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}") - } + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { - Await.result(future, timeout) + timeout.awaitResult(future) } } @@ -169,11 +170,17 @@ class BlockManagerMaster( val response = driverEndpoint. askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip - val result = Await.result(Future.sequence(futures), timeout) - if (result == null) { + implicit val sameThread = ThreadUtils.sameThread + val cbf = + implicitly[ + CanBuildFrom[Iterable[Future[Option[BlockStatus]]], + Option[BlockStatus], + Iterable[Option[BlockStatus]]]] + val blockStatus = timeout.awaitResult( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread)) + if (blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } - val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => status.map { s => (blockManagerId, s) } }.toMap @@ -192,7 +199,15 @@ class BlockManagerMaster( askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) - Await.result(future, timeout) + timeout.awaitResult(future) + } + + /** + * Find out if the executor has cached blocks. This method does not consider broadcast blocks, + * since they are not reported the master. + */ + def hasCachedBlocks(executorId: String): Boolean = { + driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId)) } /** Stop the driver endpoint, called only on the Spark driver node */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 3afb4c3c02e2..68ed9096731c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} +import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} @@ -112,6 +113,17 @@ class BlockManagerMasterEndpoint( case BlockManagerHeartbeat(blockManagerId) => context.reply(heartbeatReceived(blockManagerId)) + case HasCachedBlocks(executorId) => + blockManagerIdByExecutor.get(executorId) match { + case Some(bm) => + if (blockManagerInfo.contains(bm)) { + val bmInfo = blockManagerInfo(bm) + context.reply(bmInfo.cachedBlocks.nonEmpty) + } else { + context.reply(false) + } + case None => context.reply(false) + } } private def removeRdd(rddId: Int): Future[Seq[Int]] = { @@ -292,16 +304,16 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " + logError("Got two different block manager registrations on same executor - " + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) + removeExecutor(id.executorId) case None => } logInfo("Registering block manager %s with %s RAM, %s".format( id.hostPort, Utils.bytesToString(maxMemSize), id)) - + blockManagerIdByExecutor(id.executorId) = id - + blockManagerInfo(id) = new BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } @@ -418,6 +430,9 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] + // Cached blocks held by this BlockManager. This does not include broadcast blocks. + private val _cachedBlocks = new mutable.HashSet[BlockId] + def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) def updateLastSeenMs() { @@ -451,27 +466,35 @@ private[spark] class BlockManagerInfo( * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ + var blockStatus: BlockStatus = null if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) + blockStatus = BlockStatus(storageLevel, memSize, 0, 0) + _blocks.put(blockId, blockStatus) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) + blockStatus = BlockStatus(storageLevel, 0, diskSize, 0) + _blocks.put(blockId, blockStatus) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } if (storageLevel.useOffHeap) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, externalBlockStoreSize)) + blockStatus = BlockStatus(storageLevel, 0, 0, externalBlockStoreSize) + _blocks.put(blockId, blockStatus) logInfo("Added %s on ExternalBlockStore on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(externalBlockStoreSize))) } + if (!blockId.isBroadcast && blockStatus.isCached) { + _cachedBlocks += blockId + } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) + _cachedBlocks -= blockId if (blockStatus.storageLevel.useMemory) { logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), @@ -494,6 +517,7 @@ private[spark] class BlockManagerInfo( _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) } + _cachedBlocks -= blockId } def remainingMem: Long = _remainingMem @@ -502,6 +526,9 @@ private[spark] class BlockManagerInfo( def blocks: JHashMap[BlockId, BlockStatus] = _blocks + // This does not include broadcast blocks. + def cachedBlocks: collection.Set[BlockId] = _cachedBlocks + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem def clear() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1683576067fe..376e9eb48843 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,7 +42,6 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -108,4 +107,6 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerMaster case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + + case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 543df4e1350d..7478ab0fc2f7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -40,7 +40,7 @@ class BlockManagerSlaveEndpoint( private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 8569c6f3cbbc..c5ba9af3e265 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index a33f22ef5268..7eeabd1e0489 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter( private var objOut: SerializationStream = null private var initialized = false private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false /** * Cursors used to represent positions in the file. @@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() bs.flush() close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() } - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + commitAndCloseHasBeenCalled = true } // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) - if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter( } override def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } new FileSegment(file, initialPosition, finalPosition - initialPosition) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2a4447705fa6..91ef86389a0c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -139,8 +139,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook { () => - logDebug("Shutdown hook called") + Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } } @@ -151,7 +151,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon try { Utils.removeShutdownHook(shutdownHook) } catch { - case e: Exception => + case e: Exception => logError(s"Exception while removing shutdown hook.", e) } doStop() diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala index 8964762df6af..f39325a12d24 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala @@ -32,6 +32,8 @@ import java.nio.ByteBuffer */ private[spark] abstract class ExternalBlockManager { + protected var blockManager: BlockManager = _ + override def toString: String = {"External Block Store"} /** @@ -41,7 +43,9 @@ private[spark] abstract class ExternalBlockManager { * * @throws java.io.IOException if there is any file system failure during the initialization. */ - def init(blockManager: BlockManager, executorId: String): Unit + def init(blockManager: BlockManager, executorId: String): Unit = { + this.blockManager = blockManager + } /** * Drop the block from underlying external block store, if it exists.. @@ -73,6 +77,11 @@ private[spark] abstract class ExternalBlockManager { */ def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit + def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val bytes = blockManager.dataSerialize(blockId, values) + putBytes(blockId, bytes) + } + /** * Retrieve the block bytes. * @return Some(ByteBuffer) if the block bytes is successfully retrieved @@ -82,6 +91,17 @@ private[spark] abstract class ExternalBlockManager { */ def getBytes(blockId: BlockId): Option[ByteBuffer] + /** + * Retrieve the block data. + * @return Some(Iterator[Any]) if the block data is successfully retrieved + * None if the block does not exist in the external block store. + * + * @throws java.io.IOException if there is any file system failure in getting the block. + */ + def getValues(blockId: BlockId): Option[Iterator[_]] = { + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + } + /** * Get the size of the block saved in the underlying external block store, * which is saved before by putBytes. diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index 0bf770306ae9..291394ed3481 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -18,9 +18,11 @@ package org.apache.spark.storage import java.nio.ByteBuffer + +import scala.util.control.NonFatal + import org.apache.spark.Logging import org.apache.spark.util.Utils -import scala.util.control.NonFatal /** @@ -40,7 +42,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.map(_.getSize(blockId)).getOrElse(0) } catch { case NonFatal(t) => - logError(s"error in getSize from $blockId", t) + logError(s"Error in getSize($blockId)", t) 0L } } @@ -54,7 +56,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - putIterator(blockId, values.toIterator, level, returnValues) + putIntoExternalBlockStore(blockId, values.toIterator, returnValues) } override def putIterator( @@ -62,42 +64,70 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: values: Iterator[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - logDebug(s"Attempting to write values for block $blockId") - val bytes = blockManager.dataSerialize(blockId, values) - putIntoExternalBlockStore(blockId, bytes, returnValues) + putIntoExternalBlockStore(blockId, values, returnValues) } private def putIntoExternalBlockStore( blockId: BlockId, - bytes: ByteBuffer, + values: Iterator[_], returnValues: Boolean): PutResult = { - // So that we do not modify the input offsets ! - // duplicate does not copy buffer, so inexpensive - val byteBuffer = bytes.duplicate() - byteBuffer.rewind() - logDebug(s"Attempting to put block $blockId into ExtBlk store") + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") // we should never hit here if externalBlockManager is None. Handle it anyway for safety. try { val startTime = System.currentTimeMillis if (externalBlockManager.isDefined) { - externalBlockManager.get.putBytes(blockId, bytes) + externalBlockManager.get.putValues(blockId, values) + val size = getSize(blockId) + val data = if (returnValues) { + Left(getValues(blockId).get) + } else { + null + } val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( - blockId, Utils.bytesToString(byteBuffer.limit), finishTime - startTime)) + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) + } else { + logError(s"Error in putValues($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } catch { + case NonFatal(t) => + logError(s"Error in putValues($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } - if (returnValues) { - PutResult(bytes.limit(), Right(bytes.duplicate())) + private def putIntoExternalBlockStore( + blockId: BlockId, + bytes: ByteBuffer, + returnValues: Boolean): PutResult = { + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") + // we should never hit here if externalBlockManager is None. Handle it anyway for safety. + try { + val startTime = System.currentTimeMillis + if (externalBlockManager.isDefined) { + val byteBuffer = bytes.duplicate() + byteBuffer.rewind() + externalBlockManager.get.putBytes(blockId, byteBuffer) + val size = bytes.limit() + val data = if (returnValues) { + Right(bytes) } else { - PutResult(bytes.limit(), null) + null } + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) } else { - logError(s"error in putBytes $blockId") - PutResult(bytes.limit(), null, Seq((blockId, BlockStatus.empty))) + logError(s"Error in putBytes($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) } } catch { case NonFatal(t) => - logError(s"error in putBytes $blockId", t) - PutResult(bytes.limit(), null, Seq((blockId, BlockStatus.empty))) + logError(s"Error in putBytes($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) } } @@ -107,13 +137,19 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.map(_.removeBlock(blockId)).getOrElse(true) } catch { case NonFatal(t) => - logError(s"error in removing $blockId", t) + logError(s"Error in removeBlock($blockId)", t) true } } override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + try { + externalBlockManager.flatMap(_.getValues(blockId)) + } catch { + case NonFatal(t) => + logError(s"Error in getValues($blockId)", t) + None + } } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { @@ -121,7 +157,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.flatMap(_.getBytes(blockId)) } catch { case NonFatal(t) => - logError(s"error in getBytes from $blockId", t) + logError(s"Error in getBytes($blockId)", t) None } } @@ -130,13 +166,13 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: try { val ret = externalBlockManager.map(_.blockExists(blockId)).getOrElse(false) if (!ret) { - logInfo(s"remove block $blockId") + logInfo(s"Remove block $blockId") blockManager.removeBlock(blockId, true) } ret } catch { case NonFatal(t) => - logError(s"error in getBytes from $blockId", t) + logError(s"Error in getBytes($blockId)", t) false } } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 95e2d688d9b1..021a9facfb0b 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,6 +24,8 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") override def toString: String = { "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9..e49e39679e94 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,23 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[InputStream])] with Logging { import ShuffleBlockFetcherIterator._ @@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { case SuccessFetchResult(_, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { @@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + */ + override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { + val iteratorTry: Try[InputStream] = result match { case FailureFetchResult(_, e) => Failure(e) case SuccessFetchResult(blockId, _, buf) => // There is a chance that createInputStream can fail (e.g. fetching a local file that does // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - }) + Try(buf.createInputStream()).map { inputStream => + new BufferReleasingInputStream(inputStream, this) } } @@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator( } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index bdc6276e4191..b53c86e89a27 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -22,7 +22,10 @@ import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, Random} +import scala.util.control.NonFatal + import com.google.common.io.ByteStreams + import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} import tachyon.TachyonURI @@ -38,7 +41,6 @@ import org.apache.spark.util.Utils */ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Logging { - var blockManager: BlockManager =_ var rootDirs: String = _ var master: String = _ var client: tachyon.client.TachyonFS = _ @@ -52,7 +54,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def init(blockManager: BlockManager, executorId: String): Unit = { - this.blockManager = blockManager + super.init(blockManager, executorId) val storeDir = blockManager.conf.get(ExternalBlockStore.BASE_DIR, "/tmp_spark_tachyon") val appFolderName = blockManager.conf.get(ExternalBlockStore.FOLD_NAME) @@ -95,8 +97,29 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) - os.write(bytes.array()) - os.close() + try { + os.write(bytes.array()) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } + } + + override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val file = getFile(blockId) + val os = file.getOutStream(WriteType.TRY_CACHE) + try { + blockManager.dataSerializeStream(blockId, os, values) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put values of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { @@ -105,21 +128,31 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log return None } val is = file.getInStream(ReadType.CACHE) - assert (is != null) try { val size = file.length val bs = new Array[Byte](size.asInstanceOf[Int]) ByteStreams.readFully(is, bs) Some(ByteBuffer.wrap(bs)) } catch { - case ioe: IOException => - logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) + case NonFatal(e) => + logWarning(s"Failed to get bytes of block $blockId from Tachyon", e) None } finally { is.close() } } + override def getValues(blockId: BlockId): Option[Iterator[_]] = { + val file = getFile(blockId) + if (file == null || file.getLocationHosts().size() == 0) { + return None + } + val is = file.getInStream(ReadType.CACHE) + Option(is).map { is => + blockManager.dataDeserializeStream(blockId, is) + } + } + override def getSize(blockId: BlockId): Long = { getFile(blockId.name).length } @@ -184,7 +217,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log tachyonDir = client.getFile(path) } } catch { - case e: Exception => + case NonFatal(e) => logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) } } @@ -206,7 +239,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log Utils.deleteRecursively(tachyonDir, client) } } catch { - case e: Exception => + case NonFatal(e) => logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bfe4a180e8a6..3788916cf39b 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -19,7 +19,8 @@ package org.apache.spark.ui import java.util.Date -import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, + UIRoot} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -64,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, @@ -136,7 +137,7 @@ private[spark] object SparkUI { jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, - startTime: Long): SparkUI = { + startTime: Long): SparkUI = { create(Some(sc), conf, listenerBus, securityManager, appName, jobProgressListener = Some(jobProgressListener), startTime = startTime) } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 063e2a1f8b18..e2d25e36365f 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -35,6 +35,10 @@ private[spark] object ToolTips { val OUTPUT = "Bytes and records written to Hadoop." + val STORAGE_MEMORY = + "Memory used / total available memory for storage of data " + + "like RDD partitions cached in memory. " + val SHUFFLE_WRITE = "Bytes and records written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ad16becde85d..789803951920 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -309,7 +309,7 @@ private[spark] object UIUtils extends Logging { started: Int, completed: Int, failed: Int, - skipped:Int, + skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100) @@ -352,15 +352,17 @@ private[spark] object UIUtils extends Logging {
    -
    + @@ -110,6 +114,10 @@ public class KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors double WSSSE = clusters.computeCost(parsedData.rdd()); System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + + // Save and load model + clusters.save(sc.sc(), "myModelPath"); + KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath"); } } {% endhighlight %} @@ -124,7 +132,7 @@ Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by in fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. {% highlight python %} -from pyspark.mllib.clustering import KMeans +from pyspark.mllib.clustering import KMeans, KMeansModel from numpy import array from math import sqrt @@ -143,6 +151,10 @@ def error(point): WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y) print("Within Set Sum of Squared Error = " + str(WSSSE)) + +# Save and load model +clusters.save(sc, "myModelPath") +sameModel = KMeansModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -237,11 +249,11 @@ public class GaussianMixtureExample { GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); // Save and load GaussianMixtureModel - gmm.save(sc, "myGMMModel") - GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + gmm.save(sc.sc(), "myGMMModel"); + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel"); // Output the parameters of the mixture model for(int j=0; j + val parts = line.split(' ') + (parts(0).toLong, parts(1).toLong, parts(2).toDouble) +} -val pic = new PowerIteartionClustering() - .setK(3) - .setMaxIterations(20) +// Cluster the data into two classes using PowerIterationClustering +val pic = new PowerIterationClustering() + .setK(2) + .setMaxIterations(10) val model = pic.run(similarities) model.assignments.foreach { a => println(s"${a.id} -> ${a.cluster}") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") {% endhighlight %} A full example that produces the experiment described in the PIC paper can be found under @@ -347,11 +369,22 @@ import scala.Tuple2; import scala.Tuple3; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.PowerIterationClustering; import org.apache.spark.mllib.clustering.PowerIterationClusteringModel; -JavaRDD> similarities = ... +// Load and parse the data +JavaRDD data = sc.textFile("data/mllib/pic_data.txt"); +JavaRDD> similarities = data.map( + new Function>() { + public Tuple3 call(String line) { + String[] parts = line.split(" "); + return new Tuple3<>(new Long(parts[0]), new Long(parts[1]), new Double(parts[2])); + } + } +); +// Cluster the data into two classes using PowerIterationClustering PowerIterationClustering pic = new PowerIterationClustering() .setK(2) .setMaxIterations(10); @@ -360,6 +393,39 @@ PowerIterationClusteringModel model = pic.run(similarities); for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { System.out.println(a.id() + " -> " + a.cluster()); } + +// Save and load model +model.save(sc.sc(), "myModelPath"); +PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath"); +{% endhighlight %} +
    + +
    + +[`PowerIterationClustering`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering) +implements the PIC algorithm. +It takes an `RDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the +affinity matrix. +Calling `PowerIterationClustering.run` returns a +[`PowerIterationClusteringModel`](api/python/pyspark.mllib.html#pyspark.mllib.clustering.PowerIterationClustering), +which contains the computed clustering assignments. + +{% highlight python %} +from __future__ import print_function +from pyspark.mllib.clustering import PowerIterationClustering, PowerIterationClusteringModel + +# Load and parse the data +data = sc.textFile("data/mllib/pic_data.txt") +similarities = data.map(lambda line: tuple([float(x) for x in line.split(' ')])) + +# Cluster the data into two classes using PowerIterationClustering +model = PowerIterationClustering.train(similarities, 2, 10) + +model.assignments().foreach(lambda x: print(str(x.id) + " -> " + str(x.cluster))) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -573,6 +639,50 @@ ssc.start() ssc.awaitTermination() {% endhighlight %} +
    + +
    +First we import the neccessary classes. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.clustering import StreamingKMeans +{% endhighlight %} + +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create a model with random clusters and specify the number of clusters to find + +{% highlight python %} +model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0) +{% endhighlight %} + +Now register the streams for training and testing and start the job, printing +the predicted cluster assignments on new data points as they arrive. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} +
    + + As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point @@ -580,7 +690,3 @@ should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or id (e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change! - - - - diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7b397e30b2d9..eedc23424ad5 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -77,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -107,7 +107,8 @@ other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 -val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) +val lambda = 0.01 +val model = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha) {% endhighlight %} @@ -148,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -209,7 +210,7 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 4f2a2f71048f..d824dab1d7f7 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -31,7 +31,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseVector) and [`SparseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseVector). We recommend using the factory methods implemented in -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) to create local vectors. +[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) to create local vectors. {% highlight scala %} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -57,7 +57,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/java/org/apache/spark/mllib/linalg/DenseVector.html) and [`SparseVector`](api/java/org/apache/spark/mllib/linalg/SparseVector.html). We recommend using the factory methods implemented in -[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vector.html) to create local vectors. +[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vectors.html) to create local vectors. {% highlight java %} import org.apache.spark.mllib.linalg.Vector; @@ -84,7 +84,7 @@ and the following as sparse vectors: with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors) to create sparse vectors. {% highlight python %} import numpy as np @@ -241,7 +241,7 @@ The base class of local matrices is [`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). We recommend using the factory methods implemented -in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local +in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local matrices. {% highlight scala %} @@ -296,70 +296,6 @@ backed by an RDD of its entries. The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. In general the use of non-deterministic RDDs can lead to errors. -### BlockMatrix - -A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is -a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is -the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. -`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. -`BlockMatrix` also has a helper function `validate` which can be used to check whether the -`BlockMatrix` is set up properly. - -
    -
    - -A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. -`toBlockMatrix` creates blocks of size 1024 x 1024 by default. -Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} - -val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries -// Create a CoordinateMatrix from an RDD[MatrixEntry]. -val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) -// Transform the CoordinateMatrix to a BlockMatrix -val matA: BlockMatrix = coordMat.toBlockMatrix().cache() - -// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. -// Nothing happens if it is valid. -matA.validate() - -// Calculate A^T A. -val ata = matA.transpose.multiply(matA) -{% endhighlight %} -
    - -
    - -A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. -`toBlockMatrix` creates blocks of size 1024 x 1024 by default. -Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. - -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.distributed.BlockMatrix; -import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; -import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; - -JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries -// Create a CoordinateMatrix from a JavaRDD. -CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); -// Transform the CoordinateMatrix to a BlockMatrix -BlockMatrix matA = coordMat.toBlockMatrix().cache(); - -// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. -// Nothing happens if it is valid. -matA.validate(); - -// Calculate A^T A. -BlockMatrix ata = matA.transpose().multiply(matA); -{% endhighlight %} -
    -
    - ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD @@ -530,3 +466,67 @@ IndexedRowMatrix indexedRowMatrix = mat.toIndexedRowMatrix(); {% endhighlight %} + +### BlockMatrix + +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is +a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is +the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the +`BlockMatrix` is set up properly. + +
    +
    + +A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} + +val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries +// Create a CoordinateMatrix from an RDD[MatrixEntry]. +val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) +// Transform the CoordinateMatrix to a BlockMatrix +val matA: BlockMatrix = coordMat.toBlockMatrix().cache() + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate() + +// Calculate A^T A. +val ata = matA.transpose.multiply(matA) +{% endhighlight %} +
    + +
    + +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.distributed.BlockMatrix; +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; +import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; + +JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries +// Create a CoordinateMatrix from a JavaRDD. +CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); +// Transform the CoordinateMatrix to a BlockMatrix +BlockMatrix matA = coordMat.toBlockMatrix().cache(); + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate(); + +// Calculate A^T A. +BlockMatrix ata = matA.transpose().multiply(matA); +{% endhighlight %} +
    +
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index f723cd6b9dfa..a69e41e2a193 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th import org.apache.spark._ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("text8").map(line => line.split(" ").toSeq) @@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -380,7 +384,7 @@ data2 = labels.zip(normalizer2.transform(features)) [Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. ### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. #### Model Fitting @@ -401,7 +405,7 @@ Note that the user can also construct a `ChiSqSelectorModel` by hand by providin #### Example -The following example shows the basic use of ChiSqSelector. +The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
    @@ -410,14 +414,16 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category val discretizedData = data.map { lp => - LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) } -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features val selector = new ChiSqSelector(50) // Create ChiSqSelector model (selecting features) val transformer = selector.fit(discretizedData) @@ -446,19 +452,20 @@ JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category JavaRDD discretizedData = points.map( new Function() { @Override public LabeledPoint call(LabeledPoint lp) { final double[] discretizedFeatures = new double[lp.features().size()]; for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = lp.features().apply(i) / 16; + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); } return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); } }); -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features ChiSqSelector selector = new ChiSqSelector(50); // Create ChiSqSelector model (selecting features) final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); @@ -505,7 +512,7 @@ v_N ### Example -This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value. +This example below demonstrates how to transform vectors using a transforming vector value.
    @@ -514,16 +521,67 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors -// Load and parse the data: -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +// Create some vector data; also works for sparse vectors +val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) val transformingVector = Vectors.dense(0.0, 1.0, 2.0) val transformer = new ElementwiseProduct(transformingVector) // Batch transform and per-row transform give the same results: -val transformedData = transformer.transform(parsedData) -val transformedData2 = parsedData.map(x => transformer.transform(x)) +val transformedData = transformer.transform(data) +val transformedData2 = data.map(x => transformer.transform(x)) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +// Create some vector data; also works for sparse vectors +JavaRDD data = sc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + +// Batch transform and per-row transform give the same results: +JavaRDD transformedData = transformer.transform(data); +JavaRDD transformedData2 = data.map( + new Function() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } +); + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import ElementwiseProduct + +# Load and parse the data +sc = SparkContext() +data = sc.textFile("data/mllib/kmeans_data.txt") +parsedData = data.map(lambda x: [float(t) for t in x.split(" ")]) + +# Create weight vector. +transformingVector = Vectors.dense([0.0, 1.0, 2.0]) +transformer = ElementwiseProduct(transformingVector) + +# Batch transform +transformedData = transformer.transform(parsedData) +# Single-row transform +transformedData2 = transformer.transform(parsedData.first()) {% endhighlight %}
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 9fd9be0dd01b..bcc066a18552 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -39,11 +39,11 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters:
    -[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the +[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the FP-growth algorithm. It take a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type. Calling `FPGrowth.run` with transactions returns an -[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html) +[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel) that stores the frequent itemsets with their frequencies. {% highlight scala %} diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index f8e879496c13..d2d1cc93fe00 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -7,7 +7,19 @@ description: MLlib machine learning library overview for Spark SPARK_VERSION_SHO MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: +filtering, dimensionality reduction, as well as underlying optimization primitives. +Guides for individual algorithms are listed below. + +The API is divided into 2 parts: + +* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. +* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. + +We list major functionality from both below, with links to detailed guides. + +# MLlib types, algorithms and utilities + +This lists functionality included in `spark.mllib`, the main MLlib API. * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) @@ -39,6 +51,7 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) +* [PMML model export](mllib-pmml-model-export.html) MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, @@ -48,8 +61,8 @@ and the migration guide below will explain all changes between releases. Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. -It is currently an alpha component, and we would like to hear back from the community about -how it fits real-world use cases and how it could be improved. + +*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. Note that we will keep supporting and adding features to `spark.mllib` along with the development of `spark.ml`. @@ -57,7 +70,11 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. +More detailed guides for `spark.ml` include: + +* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts +* [Feature transformers](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API +* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API # Dependencies @@ -89,21 +106,14 @@ version 1.4 or newer. For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -## From 1.2 to 1.3 - -In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. - -* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. -* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. -* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: - * The constructor taking arguments was removed in favor of a builder patten using the default constructor plus parameter setter methods. - * Variable `model` is no longer public. -* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: - * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) - * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. -* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. -* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. - So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. ## Previous Spark Versions diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index b521c2f27cd6..5732bc4c7e79 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f labels and real labels in the test set. {% highlight scala %} -import org.apache.spark.mllib.regression.IsotonicRegression +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") @@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point => // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() println("Mean Squared Error = " + meanSquaredError) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( ).rdd()).mean(); System.out.println("Mean Squared Error = " + meanSquaredError); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 2b2be4d9d027..3927d65fbf8f 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -163,11 +163,8 @@ object, and make predictions with the resulting model to compute the training error. {% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. @@ -231,15 +228,13 @@ calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} -import java.util.Random; - import scala.Tuple2; import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.*; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; + import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; @@ -282,8 +277,8 @@ public class SVMClassifier { System.out.println("Area under ROC = " + auROC); // Save and load model - model.save(sc.sc(), "myModelPath"); - SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath"); + model.save(sc, "myModelPath"); + SVMModel sameModel = SVMModel.load(sc, "myModelPath"); } } {% endhighlight %} @@ -315,15 +310,12 @@ a dependency.
    -The following example shows how to load a sample dataset, build Logistic Regression model, +The following example shows how to load a sample dataset, build SVM model, and make predictions with the resulting model to compute the training error. -Note that the Python API does not yet support model save/load but will in the future. - {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithSGD +from pyspark.mllib.classification import SVMWithSGD, SVMModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model -model = LogisticRegressionWithSGD.train(parsedData) +model = SVMWithSGD.train(parsedData, iterations=100) # Evaluating the model on training data labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -503,7 +499,7 @@ Note that the Python API does not yet support multiclass classification and mode will in the future. {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint from numpy import array @@ -522,6 +518,10 @@ model = LogisticRegressionWithLBFGS.train(parsedData) labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -672,7 +672,7 @@ values. We compute the mean squared error at the end to evaluate Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel from numpy import array # Load and parse the data @@ -690,6 +690,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} @@ -772,6 +776,58 @@ will get better! +
    + +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create our model by initializing the weights to 0 + +{% highlight python %} +numFeatures = 3 +model = StreamingLinearRegressionWithSGD() +model.setInitialWeights([0.0, 0.0, 0.0]) +{% endhighlight %} + +Now we register the streams for training and testing and start the job. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! + +
    + @@ -785,8 +841,7 @@ gradient descent (`stepSize`, `numIterations`, `miniBatchFraction`). For each o all three possible regularizations (none, L1 or L2). For Logistic Regression, [L-BFGS](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) -version is implemented under [LogisticRegressionWithLBFGS] -(api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS), and this +version is implemented under [LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS), and this version supports both binary and multinomial Logistic Regression while SGD version only supports binary Logistic Regression. However, L-BFGS version doesn't support L1 regularization but SGD one supports L1 regularization. When L1 regularization is not required, L-BFGS version is strongly diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 4de2d9491ac2..8df68d81f3c7 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,22 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.2 to 1.3 + +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. + +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. +* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. + So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. + ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 9780ea52c499..e73bd30f3a90 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -14,14 +14,13 @@ and use it for prediction. MLlib supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -These models are typically used for [document classification] -(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). Feature values must be nonnegative. The model type is selected with an optional parameter -"Multinomial" or "Bernoulli" with "Multinomial" as the default. +"multinomial" or "bernoulli" with "multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -35,7 +34,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a +smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. @@ -54,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial") +val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() @@ -75,6 +74,8 @@ optionally smoothing parameter `lambda` as input, and output a can be used for evaluation and prediction. {% highlight java %} +import scala.Tuple2; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; @@ -82,7 +83,6 @@ import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.classification.NaiveBayes; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.regression.LabeledPoint; -import scala.Tuple2; JavaRDD training = ... // training set JavaRDD test = ... // test set @@ -119,7 +119,7 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint @@ -140,6 +140,10 @@ model = NaiveBayes.train(training, 1.0) # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md new file mode 100644 index 000000000000..42ea2ca81f80 --- /dev/null +++ b/docs/mllib-pmml-model-export.md @@ -0,0 +1,86 @@ +--- +layout: global +title: PMML model export - MLlib +displayTitle: MLlib - PMML model export +--- + +* Table of contents +{:toc} + +## MLlib supported models + +MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). + +The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. + + + + + + + + + + + + + + + + + + + + + + + + + +
    MLlib modelPMML model
    KMeansModelClusteringModel
    LinearRegressionModelRegressionModel (functionName="regression")
    RidgeRegressionModelRegressionModel (functionName="regression")
    LassoModelRegressionModel (functionName="regression")
    SVMModelRegressionModel (functionName="classification" normalizationMethod="none")
    Binary LogisticRegressionModelRegressionModel (functionName="classification" normalizationMethod="logit")
    + +## Examples +
    + +
    +To export a supported `model` (see table above) to PMML, simply call `model.toPMML`. + +Here a complete example of building a KMeansModel and print it out in PMML format: +{% highlight scala %} +import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/kmeans_data.txt") +val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() + +// Cluster the data into two classes using KMeans +val numClusters = 2 +val numIterations = 20 +val clusters = KMeans.train(parsedData, numClusters, numIterations) + +// Export to PMML +println("PMML Model:\n" + clusters.toPMML) +{% endhighlight %} + +As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats: + +{% highlight scala %} +// Export the model to a String in PMML format +clusters.toPMML + +// Export the model to a local file in PMML format +clusters.toPMML("/tmp/kmeans.xml") + +// Export the model to a directory on a distributed file system in PMML format +clusters.toPMML(sc,"/tmp/kmeans") + +// Export the model to the OutputStream in PMML format +clusters.toPMML(System.out) +{% endhighlight %} + +For unsupported models, either you will not find a `.toPMML` method or an `IllegalArgumentException` will be thrown. + +
    + +
    diff --git a/docs/monitoring.md b/docs/monitoring.md index 1e0fc150862f..bcf885fe4e68 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -178,9 +178,9 @@ Note that the history server only displays completed Spark jobs. One way to sign In addition to viewing the metrics in the UI, they are also available as JSON. This gives developers an easy way to create new visualizations and monitoring tools for Spark. The JSON is available for -both running applications, and in the history server. The endpoints are mounted at `/json/v1`. Eg., -for the history server, they would typically be accessible at `http://:18080/json/v1`, and -for a running application, at `http://localhost:4040/json/v1`. +both running applications, and in the history server. The endpoints are mounted at `/api/v1`. Eg., +for the history server, they would typically be accessible at `http://:18080/api/v1`, and +for a running application, at `http://localhost:4040/api/v1`. @@ -228,6 +228,14 @@ for a running application, at `http://localhost:4040/json/v1`. + + + + + + + +
    EndpointMeaning
    /applications/[app-id]/storage/rdd/[rdd-id] Details for the storage status of a given RDD
    /applications/[app-id]/logsDownload the event logs for all attempts of the given application as a zip file
    /applications/[app-id]/[attempt-id]/logsDownload the event logs for the specified attempt of the given application as a zip file
    When running on Yarn, each application has multiple attempts, so `[app-id]` is actually @@ -240,12 +248,12 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `json/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the -running app, you would go to `http://localhost:4040/json/v1/applications/[app-id]/jobs`. This is to +running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. # Metrics diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 27816515c5de..ae712d62746f 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -41,19 +41,20 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency artifactId = hadoop-client version = -Finally, you need to import some Spark classes and implicit conversions into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following lines: {% highlight scala %} import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.SparkConf {% endhighlight %} +(Before Spark 1.3.0, you need to explicitly `import org.apache.spark.SparkContext._` to enable essential implicit conversions.) +
    -Spark {{site.SPARK_VERSION}} works with Java 6 and higher. If you are using Java 8, Spark supports +Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java 8, Spark supports [lambda expressions](http://docs.oracle.com/javase/tutorial/java/javaOO/lambdaexpressions.html) for concisely writing functions, otherwise you can use the classes in the [org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package. @@ -97,9 +98,9 @@ to your version of HDFS. Some common HDFS version tags are listed on the [Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. -Finally, you need to import some Spark classes into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following line: -{% highlight scala %} +{% highlight python %} from pyspark import SparkContext, SparkConf {% endhighlight %} @@ -477,7 +478,6 @@ the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
    - ## RDD Operations @@ -821,11 +821,9 @@ by a key. In Scala, these operations are automatically available on RDDs containing [Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) objects -(the built-in tuples in the language, created by simply writing `(a, b)`), as long as you -import `org.apache.spark.SparkContext._` in your program to enable Spark's implicit -conversions. The key-value pair operations are available in the +(the built-in tuples in the language, created by simply writing `(a, b)`). The key-value pair operations are available in the [PairRDDFunctions](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions) class, -which automatically wraps around an RDD of tuples if you import the conversions. +which automatically wraps around an RDD of tuples. For example, the following code uses the `reduceByKey` operation on key-value pairs to count how many times each line of text occurs in a file: @@ -916,7 +914,8 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1029,7 +1028,9 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1071,7 +1072,7 @@ for details. saveAsSequenceFile(path)
    (Java and Scala) - Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that either implement Hadoop's Writable interface. In Scala, it is also + Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that implement Hadoop's Writable interface. In Scala, it is also available on types that are implicitly convertible to Writable (Spark includes conversions for basic types like Int, Double, String, etc). @@ -1122,7 +1123,7 @@ ordered data following shuffle then it's possible to use: * `sortBy` to make a globally ordered RDD Operations which can cause a shuffle include **repartition** operations like -[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +[`repartition`](#RepartitionLink) and [`coalesce`](#CoalesceLink), **'ByKey** operations (except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and **join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). @@ -1138,14 +1139,16 @@ read the relevant sorted blocks. Certain shuffle operations can consume significant amounts of heap memory since they employ in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are not cleaned up from Spark's temporary storage until Spark is stopped, which means that -long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need -to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may +consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the @@ -1213,9 +1216,11 @@ storage levels is: Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors to be smaller and to share a pool of memory, making it attractive in environments with large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory + the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. + from memory. If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon + out-of-the-box. Please refer to this page + for the suggested version pairings. @@ -1566,7 +1571,8 @@ You can see some [example Spark programs](http://spark.apache.org/examples.html) In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run Java and Scala examples by passing the class name to Spark's `bin/run-example` script; for instance: ./bin/run-example SparkPi @@ -1575,6 +1581,10 @@ For Python examples, use `spark-submit` instead: ./bin/spark-submit examples/src/main/python/pi.py +For R examples, use `spark-submit` instead: + + ./bin/spark-submit examples/src/main/r/dataframe.R + For help on optimizing your programs, the [configuration](configuration.html) and [tuning](tuning.html) guides provide information on best practices. They are especially important for making sure that your data is stored in memory in an efficient format. @@ -1582,4 +1592,4 @@ For help on deploying, the [cluster mode overview](cluster-overview.html) descri in distributed operation and supported cluster managers. Finally, full API documentation is available in -[Scala](api/scala/#org.apache.spark.package), [Java](api/java/) and [Python](api/python/). +[Scala](api/scala/#org.apache.spark.package), [Java](api/java/), [Python](api/python/) and [R](api/R/). diff --git a/docs/quick-start.md b/docs/quick-start.md index 81143da865cf..bb39e4111f24 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -184,10 +184,10 @@ scala> linesWithSpark.cache() res7: spark.RDD[String] = spark.FilteredRDD@17e51082 scala> linesWithSpark.count() -res8: Long = 15 +res8: Long = 19 scala> linesWithSpark.count() -res9: Long = 15 +res9: Long = 19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -202,10 +202,10 @@ a cluster, as described in the [programming guide](programming-guide.html#initia >>> linesWithSpark.cache() >>> linesWithSpark.count() -15 +19 >>> linesWithSpark.count() -15 +19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -423,14 +423,14 @@ dependencies to `spark-submit` through its `--py-files` argument by packaging th We can run this application using the `bin/spark-submit` script: -{% highlight python %} +{% highlight bash %} # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --master local[4] \ SimpleApp.py ... Lines with a: 46, Lines with b: 23 -{% endhighlight python %} +{% endhighlight %} @@ -444,7 +444,8 @@ Congratulations on running your first Spark application! * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run them as follows: {% highlight bash %} @@ -453,4 +454,7 @@ You can run them as follows: # For Python examples, use spark-submit directly: ./bin/spark-submit examples/src/main/python/pi.py + +# For R examples, use spark-submit directly: +./bin/spark-submit examples/src/main/r/dataframe.R {% endhighlight %} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5f1d6daeb27f..1f915d8ea1d7 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,14 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. + +{% highlight scala %} +conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +{% endhighlight %} + +For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. + # Mesos Docker Support Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` @@ -298,6 +306,20 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.constraints + Attribute based constraints to be matched against when accepting resource offers. + + Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. Refer to Mesos Attributes & Resources for more information on attributes. +
      +
    • Scalar constraints are matched with "less than equal" semantics i.e. value in the constraint must be less than or equal to the value in the resource offer.
    • +
    • Range constraints are matched with "contains" semantics i.e. value in the constraint must be within the resource offer's value.
    • +
    • Set constraints are matched with "subset of" semantics i.e. value in the constraint must be a subset of the resource offer's value.
    • +
    • Text constraints are metched with "equality" semantics i.e. value in the constraint must be exactly equal to the resource offer's value.
    • +
    • In case there is no value present as a part of the constraint any offer with the corresponding attribute will be accepted (without value check).
    • +
    + + # Troubleshooting and Debugging diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 51c133916502..de22ab557cac 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -7,6 +7,51 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Launching Spark on YARN + +Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to write to HDFS and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). + +There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. + +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +To launch a Spark application in `yarn-cluster` mode: + + `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. + +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: + + $ ./bin/spark-shell --master yarn-client + +## Adding Other JARs + +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. + + $ ./bin/spark-submit --class my.main.Class \ + --master yarn-cluster \ + --jars my-other-jar.jar,my-other-other-jar.jar + my-main-jar.jar + app_arg1 app_arg2 + + # Preparations Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. @@ -17,6 +62,38 @@ To build Spark yourself, refer to [Building Spark](building-spark.html). Most of the configs are the same for Spark on YARN as for other deployment modes. See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. +# Debugging your Application + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId + +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). + +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. + +To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a +large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +on the nodes on which containers are launched. This directory contains the launch script, JARs, and +all environment variables used for launching each container. This process is useful for debugging +classpath problems in particular. (Note that enabling this requires admin privileges on cluster +settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). + +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files + to be uploaded with the application. +- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` + (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, + the `file:` protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). + +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. + #### Spark Properties @@ -50,8 +127,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -71,9 +148,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes - + + + + + + @@ -176,8 +266,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -193,7 +283,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -229,85 +319,50 @@ Most of the configs are the same for Spark on YARN as for other deployment modes running against earlier versions, this property will be ignored. + + + + + + + + + + + + + + + + + + + +
    spark.yarn.am.waitTime 100s - In yarn-cluster mode, time for the application master to wait for the - SparkContext to be initialized. In yarn-client mode, time for the application master to wait + In `yarn-cluster` mode, time for the application master to wait for the + SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait for the driver to connect to it.
    spark.yarn.scheduler.heartbeat.interval-ms50003000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. + The value is capped at half the value of YARN's configuration for the expiry interval + (yarn.am.liveness-monitor.expiry-interval-ms). +
    spark.yarn.scheduler.initial-allocation.interval200ms + The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager + when there are pending container allocation requests. It should be no larger than + spark.yarn.scheduler.heartbeat.interval-ms. The allocation interval will doubled on + successive eager heartbeats if pending containers still exist, until + spark.yarn.scheduler.heartbeat.interval-ms is reached.
    Add the environment variable specified by EnvironmentVariableName to the Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In yarn-cluster mode this controls - the environment of the SPARK driver and in yarn-client mode it only controls + these and to set multiple environment variables. In `yarn-cluster` mode this controls + the environment of the SPARK driver and in `yarn-client` mode it only controls the environment of the executor launcher.
    (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use `spark.driver.extraJavaOptions` instead.
    spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. + This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. +
    spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure HDFS. +
    spark.yarn.config.gatewayPath(none) + A path that is valid on the gateway host (the host where a Spark application is started) but may + differ for paths for the same resource in other nodes in the cluster. Coupled with + spark.yarn.config.replacementPath, this is used to support clusters with + heterogeneous configurations, so that Spark can correctly launch remote processes. +

    + The replacement path normally will contain a reference to some environment variable exported by + YARN (and, thus, visible to Spark containers). +

    + For example, if the gateway node has Hadoop libraries installed on /disk1/hadoop, and + the location of the Hadoop install is exported by YARN as the HADOOP_HOME + environment variable, setting this value to /disk1/hadoop and the replacement path to + $HADOOP_HOME will make sure that paths used to launch remote processes properly + reference the local YARN configuration. +

    spark.yarn.config.replacementPath(none) + See spark.yarn.config.gatewayPath. +
    -# Launching Spark on YARN - -Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. The -configuration contained in this directory will be distributed to the YARN cluster so that all -containers used by the application use the same configuration. If the configuration references -Java system properties or environment variables not managed by YARN, they should also be set in the -Spark application's configuration (driver, executors, and the AM when running in client mode). - -There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. - -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". - -To launch a Spark application in yarn-cluster mode: - - ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - -For example: - - $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ - --driver-memory 4g \ - --executor-memory 2g \ - --executor-cores 1 \ - --queue thequeue \ - lib/spark-examples*.jar \ - 10 - -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. - -To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: - - $ ./bin/spark-shell --master yarn-client - -## Adding Other JARs - -In yarn-cluster mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. - - $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar - app_arg1 app_arg2 - -# Debugging your Application - -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. - - yarn logs -applicationId - -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). - -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. - -To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` -on the nodes on which containers are launched. This directory contains the launch script, JARs, and -all environment variables used for launching each container. This process is useful for debugging -classpath problems in particular. (Note that enabling this requires admin privileges on cluster -settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). - -To use a custom log4j configuration for the application master or executors, there are two options: - -- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files - to be uploaded with the application. -- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` - (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, - the `file:` protocol should be explicitly provided, and the file needs to exist locally on all - the nodes. - -Note that for the first option, both executors and the application master will share the same -log4j configuration, which may cause issues when they run on the same node (e.g. trying to write -to the same log file). - -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. - # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 0eed9adacf12..4f71fbc086cd 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -24,7 +24,7 @@ the master's web UI, which is [http://localhost:8080](http://localhost:8080) by Similarly, you can start one or more workers and connect them to the master via: - ./sbin/start-slave.sh + ./sbin/start-slave.sh Once you have started a worker, look at the master's web UI ([http://localhost:8080](http://localhost:8080) by default). You should see the new node listed there, along with its number of CPUs and memory (minus one gigabyte left for the OS). @@ -77,7 +77,7 @@ Note, the master machine accesses each of the worker machines via ssh. By defaul If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. -Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: +Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`: - `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on. - `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. diff --git a/docs/sparkr.md b/docs/sparkr.md new file mode 100644 index 000000000000..095ea4308cfe --- /dev/null +++ b/docs/sparkr.md @@ -0,0 +1,232 @@ +--- +layout: global +displayTitle: SparkR (R on Spark) +title: SparkR (R on Spark) +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. +In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that +supports operations like selection, filtering, aggregation etc. (similar to R data frames, +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. + +# SparkR DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +structured data files, tables in Hive, external databases, or existing local R data frames. + +All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. + +## Starting Up: SparkContext, SQLContext + +
    +The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. +You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name +, any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, +which can be created from the SparkContext. If you are working from the SparkR shell, the +`SQLContext` and `SparkContext` should already be created for you. + +{% highlight r %} +sc <- sparkR.init() +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} + +
    + +## Creating DataFrames +With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). + +### From local data frames +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. + +
    +{% highlight r %} +df <- createDataFrame(sqlContext, faithful) + +# Displays the content of the DataFrame to stdout +head(df) +## eruptions waiting +##1 3.600 79 +##2 1.800 54 +##3 3.333 74 + +{% endhighlight %} +
    + +### From Data Sources + +SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by +specifying `--packages` with `spark-submit` or `sparkR` commands, or if creating context through `init` +you can specify the packages with the `packages` argument. + +
    +{% highlight r %} +sc <- sparkR.init(packages="com.databricks:spark-csv_2.11:1.0.3") +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
    + +We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +
    + +{% highlight r %} +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +head(people) +## age name +##1 NA Michael +##2 30 Andy +##3 19 Justin + +# SparkR automatically infers the schema from the JSON file +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +{% endhighlight %} +
    + +The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example +to a Parquet file using `write.df` + +
    +{% highlight r %} +write.df(people, path="people.parquet", source="parquet", mode="overwrite") +{% endhighlight %} +
    + +### From Hive tables + +You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). + +
    +{% highlight r %} +# sc is an existing SparkContext. +hiveContext <- sparkRHive.init(sc) + +sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- hiveContext.sql("FROM src SELECT key, value") + +# results is now a DataFrame +head(results) +## key value +## 1 238 val_238 +## 2 86 val_86 +## 3 311 val_311 + +{% endhighlight %} +
    + +## DataFrame Operations + +SparkR DataFrames support a number of functions to do structured data processing. +Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: + +### Selecting rows, columns + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, faithful) + +# Get basic information about the DataFrame +df +## DataFrame[eruptions:double, waiting:double] + +# Select only the "eruptions" column +head(select(df, df$eruptions)) +## eruptions +##1 3.600 +##2 1.800 +##3 3.333 + +# You can also pass in column name as strings +head(select(df, "eruptions")) + +# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +head(filter(df, df$waiting < 50)) +## eruptions waiting +##1 1.750 47 +##2 1.750 47 +##3 1.867 48 + +{% endhighlight %} + +
    + +### Grouping, Aggregation + +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below + +
    +{% highlight r %} + +# We use the `n` operator to count the number of times each waiting time appears +head(summarize(groupBy(df, df$waiting), count = n(df$waiting))) +## waiting count +##1 81 13 +##2 60 6 +##3 68 1 + +# We can also sort the output from the aggregation to get the most common waiting times +waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting)) +head(arrange(waiting_counts, desc(waiting_counts$count))) + +## waiting count +##1 78 15 +##2 83 14 +##3 81 13 + +{% endhighlight %} +
    + +### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can assign this to a new column in the same DataFrame +df$waiting_secs <- df$waiting * 60 +head(df) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 + +{% endhighlight %} +
    + +## Running SQL Queries from SparkR +A SparkR DataFrame can also be registered as a temporary table in Spark SQL and registering a DataFrame as a table allows you to run SQL queries over its data. +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +{% highlight r %} +# Load a JSON file +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 78b8e8ad515a..88c96a9a095b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,17 +11,18 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. # DataFrames A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell` or the `pyspark` shell. +All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. -## Starting Point: `SQLContext` +## Starting Point: SQLContext
    @@ -64,6 +65,17 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) {% endhighlight %} +
    + +
    + +The entry point into all relational functionality in Spark is the +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
    @@ -97,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Displays the content of the DataFrame to stdout df.show() @@ -110,7 +122,7 @@ df.show() JavaSparkContext sc = ...; // An existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Displays the content of the DataFrame to stdout df.show(); @@ -123,13 +135,26 @@ df.show(); from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Displays the content of the DataFrame to stdout df.show() {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- SQLContext(sc) + +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Displays the content of the DataFrame to stdout +showDF(df) +{% endhighlight %} + +
    + @@ -146,7 +171,7 @@ val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Show the content of the DataFrame df.show() @@ -196,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Show the content of the DataFrame df.show(); @@ -252,7 +277,7 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) # Create the DataFrame -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Show the content of the DataFrame df.show() @@ -296,6 +321,57 @@ df.groupBy("age").count().show() {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) + +# Create the DataFrame +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Show the content of the DataFrame +showDF(df) +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +printSchema(df) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +showDF(select(df, "name")) +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +showDF(select(df, df$name, df$age + 1)) +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +showDF(where(df, df$age > 21)) +## age name +## 30 Andy + +# Count people by age +showDF(count(groupBy(df, "age"))) +## age count +## null 1 +## 19 1 +## 30 1 + +{% endhighlight %} + +
    + @@ -325,6 +401,14 @@ sqlContext = SQLContext(sc) df = sqlContext.sql("SELECT * FROM table") {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +df <- sql(sqlContext, "SELECT * FROM table") +{% endhighlight %} +
    + @@ -693,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +val df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %}
    @@ -703,8 +787,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); -df.select("name", "favorite_color").save("namesAndFavColors.parquet"); +DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); {% endhighlight %} @@ -714,11 +798,20 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %} + + +
    + +{% highlight r %} +df <- loadDF(sqlContext, "people.parquet") +saveDF(select(df, "name", "age"), "namesAndAges.parquet") +{% endhighlight %} +
    @@ -726,16 +819,16 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") You can also manually specify the data source that will be used along with any extra options that you would like to pass to the data source. Data sources are specified by their fully qualified -name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use the shorted -name (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") +df.select("name", "age").write.format("json").save("namesAndAges.json") {% endhighlight %}
    @@ -744,8 +837,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); -df.select("name", "age").save("namesAndAges.parquet", "parquet"); +DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); +df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); {% endhighlight %} @@ -755,8 +848,18 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") +df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") + +{% endhighlight %} + +
    +
    + +{% highlight r %} + +df <- loadDF(sqlContext, "people.json", "json") +saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") {% endhighlight %} @@ -804,7 +907,7 @@ new data. Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. @@ -844,11 +947,11 @@ import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. -people.saveAsParquetFile("people.parquet") +people.write.parquet("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. -val parquetFile = sqlContext.parquetFile("people.parquet") +val parquetFile = sqlContext.read.parquet("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile") @@ -866,13 +969,13 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet"); +schemaPeople.write().parquet("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. -DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); +DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); -//Parquet files can also be registered as tables and then used in SQL statements. +// Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.javaRDD().map(new Function() { @@ -892,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function() schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet") +schemaPeople.write.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.parquetFile("people.parquet") +parquetFile = sqlContext.read.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -908,6 +1011,40 @@ for teenName in teenNames.collect():
    +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +schemaPeople # The DataFrame from the previous example. + +# DataFrames can be saved as Parquet files, maintaining the schema information. +saveAsParquetFile(schemaPeople, "people.parquet") + +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# The result of loading a parquet file is also a DataFrame. +parquetFile <- parquetFile(sqlContext, "people.parquet") + +# Parquet files can also be registered as tables and then used in SQL statements. +registerTempTable(parquetFile, "parquetFile"); +teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) +for (teenName in collect(teenNames)) { + cat(teenName, "\n") +} +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.sql("REFRESH TABLE my_table") +{% endhighlight %} + +
    +
    {% highlight sql %} @@ -926,12 +1063,12 @@ SELECT * FROM parquetTable
    -### Partition discovery +### Partition Discovery Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For exmaple, we can store all our previously used +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -959,9 +1096,9 @@ path {% endhighlight %} -By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will -automatically extract the partitioning information from the paths. Now the schema of the returned -DataFrame becomes: +By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +will automatically extract the partitioning information from the paths. +Now the schema of the returned DataFrame becomes: {% highlight text %} @@ -974,9 +1111,13 @@ root {% endhighlight %} Notice that the data types of the partitioning columns are automatically inferred. Currently, -numeric data types and string type are supported. +numeric data types and string type are supported. Sometimes users may not want to automatically +infer the data types of the partitioning columns. For these use cases, the automatic type inference +can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to +`true`. When type inference is disabled, string type will be used for the partitioning columns. -### Schema merging + +### Schema Merging Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with a simple schema, and gradually add more columns to the schema as needed. In this way, users may end @@ -993,20 +1134,20 @@ source is now able to automatically detect this case and merge schemas of all th import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory -val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") -df1.saveAsParquetFile("data/test_table/key=1") +val df1 = sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column -val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") -df2.saveAsParquetFile("data/test_table/key=2") +val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.parquetFile("data/test_table") +val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together -// with the partiioning column appeared in the partition directory paths. +// with the partitioning column appeared in the partition directory paths. // root // |-- single: int (nullable = true) // |-- double: int (nullable = true) @@ -1033,11 +1174,38 @@ df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) df2.save("data/test_table/key=2", "parquet") # Read the partitioned table -df3 = sqlContext.parquetFile("data/test_table") +df3 = sqlContext.load("data/test_table", "parquet") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + + + +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") + +# Read the partitioned table +df3 <- loadDF(sqlContext, "data/test_table", "parquet") +printSchema(df3) + +# The final schema consists of all 3 columns in the Parquet files together +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1049,6 +1217,79 @@ df3.printSchema()
    +### Hive metastore Parquet table conversion + +When reading from and writing to Hive metastore Parquet tables, Spark SQL will try to use its own +Parquet support instead of Hive SerDe for better performance. This behavior is controlled by the +`spark.sql.hive.convertMetastoreParquet` configuration, and is turned on by default. + +#### Hive/Parquet Schema Reconciliation + +There are two key differences between Hive and Parquet from the perspective of table schema +processing. + +1. Hive is case insensitive, while Parquet is not +1. Hive considers all columns nullable, while nullability in Parquet is significant + +Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: + +1. Fields that have the same name in both schema must have the same data type regardless of + nullability. The reconciled field should have the data type of the Parquet side, so that + nullability is respected. + +1. The reconciled schema contains exactly those fields defined in Hive metastore schema. + + - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. + - Any fileds that only appear in the Hive metastore schema are added as nullable field in the + reconciled schema. + +#### Metadata Refreshing + +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are +updated by Hive or other external tools, you need to refresh them manually to ensure consistent +metadata. + +
    + +
    + +{% highlight scala %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight java %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
    + +
    + +{% highlight sql %} +REFRESH TABLE my_table; +{% endhighlight %} + +
    + +
    + ### Configuration Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running @@ -1061,7 +1302,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do - not differentiate between binary data and strings when writing out the Parquet schema. This + not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. @@ -1078,7 +1319,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.cacheMetadata true - Turns on caching of Parquet schema metadata. Can speed up querying of static data. + Turns on caching of Parquet schema metadata. Can speed up querying of static data. @@ -1094,7 +1335,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Paruet 1.6.0rc3 (PARQUET-136). + bug in Parquet 1.6.0rc3 (PARQUET-136). However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn this feature on. @@ -1107,6 +1348,34 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support. + + spark.sql.parquet.output.committer.class + org.apache.parquet.hadoop.
    ParquetOutputCommitter
    + +

    + The output committer class used by Parquet. The specified class needs to be a subclass of + org.apache.hadoop.
    mapreduce.OutputCommitter
    . Typically, it's also a + subclass of org.apache.parquet.hadoop.ParquetOutputCommitter. +

    +

    + Note: +

      +
    • + This option must be set via Hadoop Configuration rather than Spark + SQLConf. +
    • +
    • + This option overrides spark.sql.sources.
      outputCommitterClass
      . +
    • +
    +

    +

    + Spark SQL comes with a builtin + org.apache.spark.sql.
    parquet.DirectParquetOutputCommitter
    , which can be more + efficient then the default Parquet output committer when writing data to S3. +

    + + ## JSON Datasets @@ -1114,12 +1383,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: +This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +or a JSON file. -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. - -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1130,8 +1397,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a DataFrame from the file(s) pointed to by path -val people = sqlContext.jsonFile(path) +val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1149,19 +1415,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +val anotherPeople = sqlContext.read.json(anotherPeopleRDD) {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext` : - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1171,9 +1435,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. -String path = "examples/src/main/resources/people.json"; -// Create a DataFrame from the file(s) pointed to by path -DataFrame people = sqlContext.jsonFile(path); +DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -1192,18 +1454,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json` on a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1214,9 +1473,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. -path = "examples/src/main/resources/people.json" -# Create a DataFrame from the file(s) pointed to by path -people = sqlContext.jsonFile(path) +people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1238,6 +1495,39 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %}
    +
    +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. using +the `jsonFile` function, which loads data from a directory of JSON files where each line of the +files is a JSON object. + +Note that the file that is offered as _a json file_ is not a typical JSON file. Each +line must contain a separate, self-contained valid JSON object. As a consequence, +a regular multi-line JSON file will most often fail. + +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRSQL.init(sc) + +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path <- "examples/src/main/resources/people.json" +# Create a DataFrame from the file(s) pointed to by path +people <- jsonFile(sqlContext, path) + +# The inferred schema can be visualized using the printSchema() method. +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql methods provided by `sqlContext`. +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +{% endhighlight %} +
    +
    {% highlight sql %} @@ -1265,7 +1555,12 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the +YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the +`spark-submit` command. +
    @@ -1294,12 +1589,12 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc.sc); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); @@ -1314,10 +1609,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be -expressed in HiveQL. - +adds support for finding tables in the MetaStore and writing queries using HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext @@ -1331,9 +1623,91 @@ results = sqlContext.sql("FROM src SELECT key, value").collect() {% endhighlight %} +
    + +
    + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRHive.init(sc) + +sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results = sqlContext.sql("FROM src SELECT key, value").collect() + +{% endhighlight %} +
    +### Interacting with Different Versions of Hive Metastore + +One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. + +Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts. + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.hive.metastore.version0.13.1 + Version of the Hive metastore. Available + options are 0.12.0 and 0.13.1. Support for more versions is coming in the future. +
    spark.sql.hive.metastore.jarsbuiltin + Location of the jars that should be used to instantiate the HiveMetastoreClient. This + property can be one of three options: +
      +
    1. builtin
    2. + Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + enabled. When this option is chosen, spark.sql.hive.metastore.version must be + either 0.13.1 or not defined. +
    3. maven
    4. + Use Hive jars of specified version downloaded from Maven repositories. +
    5. A classpath in the standard format for both Hive and Hadoop.
    6. +
    +
    spark.sql.hive.metastore.sharedPrefixescom.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc
    +

    + A comma separated list of class prefixes that should be loaded using the classloader that is + shared between Spark SQL and a specific version of Hive. An example of classes that should + be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + to be shared are those that interact with classes that are already shared. For example, + custom appenders that are used by log4j. +

    +
    spark.sql.hive.metastore.barrierPrefixes(empty) +

    + A comma separated list of class prefixes that should explicitly be reloaded for each version + of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + prefix that typically would be shared (i.e. org.apache.spark.*). +

    +
    + + ## JDBC To Other Databases Spark SQL also includes a data source that can read data from other databases using JDBC. This @@ -1367,7 +1741,7 @@ the Data Sources API. The following options are supported: dbtable - The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses. @@ -1399,9 +1773,9 @@ the Data Sources API. The following options are supported:
    {% highlight scala %} -val jdbcDF = sqlContext.load("jdbc", Map( - "url" -> "jdbc:postgresql:dbserver", - "dbtable" -> "schema.tablename")) +val jdbcDF = sqlContext.read.format("jdbc").options( + Map("url" -> "jdbc:postgresql:dbserver", + "dbtable" -> "schema.tablename")).load() {% endhighlight %}
    @@ -1414,7 +1788,7 @@ Map options = new HashMap(); options.put("url", "jdbc:postgresql:dbserver"); options.put("dbtable", "schema.tablename"); -DataFrame jdbcDF = sqlContext.load("jdbc", options) +DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% endhighlight %} @@ -1424,7 +1798,17 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.read.format('jdbc').options(url = 'jdbc:postgresql:dbserver', dbtable='schema.tablename').load() + +{% endhighlight %} + +
    + +
    + +{% highlight r %} + +df <- loadDF(sqlContext, source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") {% endhighlight %} @@ -1501,7 +1885,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command - `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. + ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. @@ -1520,11 +1904,20 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. + + spark.sql.planner.externalSort + false + + When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. + + # Distributed SQL Engine -Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. +In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, +without the need to write any code. ## Running the Thrift JDBC/ODBC server @@ -1538,7 +1931,7 @@ To start the JDBC/ODBC server, run the following in the Spark directory: This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. By default, the server listens on localhost:10000. You may override this -bahaviour via either environment variables, i.e.: +behaviour via either environment variables, i.e.: {% highlight bash %} export HIVE_SERVER2_THRIFT_PORT= @@ -1603,6 +1996,25 @@ options. ## Upgrading from Spark SQL 1.3 to 1.4 +#### DataFrame data reader/writer interface + +Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) +and writing data out (`DataFrame.write`), +and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). + +See the API docs for `SQLContext.read` ( + Scala, + Java, + Python +) and `DataFrame.write` ( + Scala, + Java, + Python +) more information. + + +#### DataFrame.groupBy retains grouping columns + Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
    @@ -1726,7 +2138,7 @@ sqlContext.udf.register("strLen", (s: String) => s.length())
    {% highlight java %} -sqlContext.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType); {% endhighlight %}
    @@ -2354,5 +2766,151 @@ from pyspark.sql.types import *
    +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Data typeValue type in RAPI to access or create a data type
    ByteType + integer
    + Note: Numbers will be converted to 1-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -128 to 127. +
    + "byte" +
    ShortType + integer
    + Note: Numbers will be converted to 2-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -32768 to 32767. +
    + "short" +
    IntegerType integer + "integer" +
    LongType + integer
    + Note: Numbers will be converted to 8-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of + -9223372036854775808 to 9223372036854775807. + Otherwise, please convert data to decimal.Decimal and use DecimalType. +
    + "long" +
    FloatType + numeric
    + Note: Numbers will be converted to 4-byte single-precision floating + point numbers at runtime. +
    + "float" +
    DoubleType numeric + "double" +
    DecimalType Not supported + Not supported +
    StringType character + "string" +
    BinaryType raw + "binary" +
    BooleanType logical + "bool" +
    TimestampType POSIXct + "timestamp" +
    DateType Date + "date" +
    ArrayType vector or list + list(type="array", elementType=elementType, containsNull=[containsNull])
    + Note: The default value of containsNull is True. +
    MapType environment + list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
    + Note: The default value of valueContainsNull is True. +
    StructType named list + list(type="struct", fields=fields)
    + Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
    StructField The value type in R of the data type of this field + (For example, integer for a StructField with the data type IntegerType) + list(name=name, type=dataType, nullable=nullable) +
    + +
    +
    diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 6a2048121f8b..a75587a92adc 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). +the ones for which it has built-in support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. Note that custom receivers can be implemented @@ -21,15 +21,15 @@ A custom receiver must extend this abstract class by implementing two methods - `onStop()`: Things to do to stop receiving data. Both `onStart()` and `onStop()` must not block indefinitely. Typically, `onStart()` would start the threads -that responsible for receiving the data and `onStop()` would ensure that the receiving by those threads +that are responsible for receiving the data, and `onStop()` would ensure that these threads receiving the data are stopped. The receiving threads can also use `isStopped()`, a `Receiver` method, to check whether they should stop receiving data. Once the data is received, that data can be stored inside Spark by calling `store(data)`, which is a method provided by the Receiver class. -There are number of flavours of `store()` which allow you store the received data -record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavour of -`store()` used to implemented a receiver affects its reliability and fault-tolerance semantics. +There are a number of flavors of `store()` which allow one to store the received data +record-at-a-time or as whole collection of objects / serialized bytes. Note that the flavor of +`store()` used to implement a receiver affects its reliability and fault-tolerance semantics. This is discussed [later](#receiver-reliability) in more detail. Any exception in the receiving threads should be caught and handled properly to avoid silent @@ -60,7 +60,7 @@ class CustomReceiver(host: String, port: Int) def onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -123,7 +123,7 @@ public class JavaCustomReceiver extends Receiver { public void onStop() { // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // is designed to stop by itself if isStopped() returns false } /** Create a socket connection and receive data until receiver is stopped */ @@ -167,7 +167,7 @@ public class JavaCustomReceiver extends Receiver { The custom receiver can be used in a Spark Streaming application by using `streamingContext.receiverStream()`. This will create -input DStream using data received by the instance of custom receiver, as shown below +an input DStream using data received by the instance of custom receiver, as shown below:
    @@ -206,22 +206,20 @@ there are two kinds of receivers based on their reliability and fault-tolerance and stored in Spark reliably (that is, replicated successfully). Usually, implementing this receiver involves careful consideration of the semantics of source acknowledgements. -1. *Unreliable Receiver* - These are receivers for unreliable sources that do not support - acknowledging. Even for reliable sources, one may implement an unreliable receiver that - do not go into the complexity of acknowledging correctly. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgement to a source. This can be used for sources that do not support acknowledgement, or even for reliable sources when one does not want or need to go into the complexity of acknowledgement. To implement a *reliable receiver*, you have to use `store(multiple-records)` to store data. -This flavour of `store` is a blocking call which returns only after all the given records have +This flavor of `store` is a blocking call which returns only after all the given records have been stored inside Spark. If the receiver's configured storage level uses replication (enabled by default), then this call returns after replication has completed. Thus it ensures that the data is reliably stored, and the receiver can now acknowledge the -source appropriately. This ensures that no data is caused when the receiver fails in the middle +source appropriately. This ensures that no data is lost when the receiver fails in the middle of replicating data -- the buffered data will not be acknowledged and hence will be later resent by the source. An *unreliable receiver* does not have to implement any of this logic. It can simply receive records from the source and insert them one-at-a-time using `store(single-record)`. While it does -not get the reliability guarantees of `store(multiple-records)`, it has the following advantages. +not get the reliability guarantees of `store(multiple-records)`, it has the following advantages: - The system takes care of chunking that data into appropriate sized blocks (look for block interval in the [Spark Streaming Programming Guide](streaming-programming-guide.html)). diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index c8ab146bcae0..de0461010dae 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -58,6 +58,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java).
    +
    + from pyspark.streaming.flume import FlumeUtils + + flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). +
    Note that the hostname should be the same as the one used by the resource manager in the @@ -99,6 +108,12 @@ Configuring Flume on the chosen machine requires the following two steps. artifactId = scala-library version = {{site.SCALA_VERSION}} + (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)). + + groupId = org.apache.commons + artifactId = commons-lang3 + version = 3.3.2 + 2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. agent.sinks = spark @@ -129,6 +144,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
    + from pyspark.streaming.flume import FlumeUtils + + addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] + flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). +
    See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 64714f0b799f..775d508d4879 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,12 +2,12 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; @@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application. [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
    @@ -74,15 +74,15 @@ Next, we discuss how to use this approach in your streaming application. [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This is a new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature in Spark 1.3 and is only available in the Scala and Java API. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. -This approach has the following advantages over the received-based approach (i.e. Approach 1). +This approach has the following advantages over the receiver-based approach (i.e. Approach 1). -- *Simplified Parallelism:* No need to create multiple input Kafka streams and union-ing them. With `directStream`, Spark Streaming will create as many RDD partitions as there is Kafka partitions to consume, which will all read data from Kafka in parallel. So there is one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminate the problem as there is no receiver, and hence no need for Write Ahead Logs. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper and offsets tracked only by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). @@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application. streamingContext, [map of Kafka parameters], [set of topics to consume]) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; @@ -116,8 +116,15 @@ Next, we discuss how to use this approach in your streaming application. [map of Kafka parameters], [set of topics to consume]); See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). +
    +
    + from pyspark.streaming.kafka import KafkaUtils + directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py).
    @@ -128,29 +135,60 @@ Next, we discuss how to use this approach in your streaming application.
    - directKafkaStream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges] - // offsetRanges.length = # of Kafka partitions being consumed - ... + // Hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + directKafkaStream.transform { rdd => + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.map { + ... + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + ... }
    - directKafkaStream.foreachRDD( - new Function, Void>() { - @Override - public Void call(JavaPairRDD rdd) throws IOException { - OffsetRange[] offsetRanges = ((HasOffsetRanges)rdd).offsetRanges - // offsetRanges.length = # of Kafka partitions being consumed - ... - return null; - } + // Hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); + + directKafkaStream.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + return rdd; + } + } + ).map( + ... + ).foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + ... + return null; } + } );
    +
    + Not supported yet
    +
    You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. - Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. + Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. -3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. \ No newline at end of file +3. **Deploying:** This is same as the first approach, for Scala, Java and Python. diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 379eb513d521..aa9749afbc86 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -32,7 +32,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream val kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]) + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. @@ -44,7 +45,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( - streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]); + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2); See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. @@ -54,19 +56,23 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream - - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from - - The application name used in the streaming context becomes the Kinesis application name + - `[Kineiss app name]`: The application name that will be used to checkpoint the Kinesis + sequence numbers in DynamoDB table. - The application name must be unique for a given account and region. - - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization. - - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table. + - If the table exists but has incorrect checkpoint information (for a different stream, or + old expired sequenced numbers), then there may be temporary errors. + - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from. - `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region). + - `[region name]`: Valid Kinesis region names can be found [here](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html). + - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application. - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + In other versions of the API, you can also specify the AWS access key and secret key directly. 3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). @@ -122,12 +128,12 @@ To run the example,
    - bin/run-example streaming.KinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
    - bin/run-example streaming.JavaKinesisWordCountASL [Kinesis stream name] [endpoint URL] + bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
    @@ -136,7 +142,7 @@ To run the example, - To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer. - bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10 + bin/run-example streaming.KinesisWordProducerASL [Kinesis stream name] [endpoint URL] 1000 10 This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bd863d48d53e..e72d5580dae5 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -11,7 +11,7 @@ description: Spark Streaming programming guide and tutorial for Spark SPARK_VERS # Overview Spark Streaming is an extension of the core Spark API that enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ, Kinesis or TCP sockets can be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis, or TCP sockets, and can be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's @@ -52,7 +52,7 @@ different languages. **Note:** Python API for Spark Streaming has been introduced in Spark 1.2. It has all the DStream transformations and almost all the output operations available in Scala and Java interfaces. -However, it has only support for basic sources like text files and text data over sockets. +However, it only has support for basic sources like text files and text data over sockets. APIs for additional sources, like Kafka and Flume, will be available in the future. Further information about available features in the Python API are mentioned throughout this document; look out for the tag @@ -69,15 +69,15 @@ do is as follows.
    -First, we import the names of the Spark Streaming classes, and some implicit -conversions from StreamingContext into our environment, to add useful methods to +First, we import the names of the Spark Streaming classes and some implicit +conversions from StreamingContext into our environment in order to add useful methods to other classes we need (like DStream). [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the -main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. +main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and a batch interval of 1 second. {% highlight scala %} import org.apache.spark._ import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. // The master requires 2 cores to prevent from a starvation scenario. @@ -96,7 +96,7 @@ val lines = ssc.socketTextStream("localhost", 9999) This `lines` DStream represents the stream of data that will be received from the data server. Each record in this DStream is a line of text. Next, we want to split the lines by -space into words. +space characters into words. {% highlight scala %} // Split each line into words @@ -109,7 +109,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Next, we want to count these words. {% highlight scala %} -import org.apache.spark.streaming.StreamingContext._ // not necessary in Spark 1.3+ +import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Count each word in each batch val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) @@ -463,7 +463,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). Note that this internally creates a [SparkContext](api/scala/index.html#org.apache.spark.SparkContext) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `StreamingContext` object can also be created from an existing `SparkContext` object. @@ -498,7 +498,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details. A `JavaStreamingContext` object can also be created from an existing `JavaSparkContext`. @@ -531,7 +531,7 @@ receive it there. However, for local testing and unit tests, you can pass "local in-process (detects the number of cores in the local system). The batch interval must be set based on the latency requirements of your application -and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-interval) section for more details.
    @@ -549,7 +549,7 @@ After a context is defined, you have to do the following. - Once a context has been started, no new streaming computations can be set up or added to it. - Once a context has been stopped, it cannot be restarted. - Only one StreamingContext can be active in a JVM at the same time. -- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set optional parameter of `stop()` called `stopSparkContext` to false. +- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set the optional parameter of `stop()` called `stopSparkContext` to false. - A SparkContext can be re-used to create multiple StreamingContexts, as long as the previous StreamingContext is stopped (without stopping the SparkContext) before the next StreamingContext is created. *** @@ -583,7 +583,7 @@ the `flatMap` operation is applied on each RDD in the `lines` DStream to generat These underlying RDD transformations are computed by the Spark engine. The DStream operations -hide most of these details and provide the developer with higher-level API for convenience. +hide most of these details and provide the developer with a higher-level API for convenience. These operations are discussed in detail in later sections. *** @@ -600,7 +600,7 @@ data from a source and stores it in Spark's memory for processing. Spark Streaming provides two categories of built-in streaming sources. - *Basic sources*: Sources directly available in the StreamingContext API. - Example: file systems, socket connections, and Akka actors. + Examples: file systems, socket connections, and Akka actors. - *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section. @@ -610,11 +610,11 @@ We are going to discuss some of the sources present in each category later in th Note that, if you want to receive multiple streams of data in parallel in your streaming application, you can create multiple input DStreams (discussed further in the [Performance Tuning](#level-of-parallelism-in-data-receiving) section). This will -create multiple receivers which will simultaneously receive multiple data streams. But note that -Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the -Spark Streaming application. Hence, it is important to remember that Spark Streaming application +create multiple receivers which will simultaneously receive multiple data streams. But note that a +Spark worker/executor is a long-running task, hence it occupies one of the cores allocated to the +Spark Streaming application. Therefore, it is important to remember that a Spark Streaming application needs to be allocated enough cores (or threads, if running locally) to process the received data, -as well as, to run the receiver(s). +as well as to run the receiver(s). ##### Points to remember {:.no_toc} @@ -623,13 +623,13 @@ as well as, to run the receiver(s). Either of these means that only one thread will be used for running tasks locally. If you are using a input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will be used to run the receiver, leaving no thread for processing the received data. Hence, when - running locally, always use "local[*n*]" as the master URL where *n* > number of receivers to run - (see [Spark Properties](configuration.html#spark-properties.html) for information on how to set + running locally, always use "local[*n*]" as the master URL, where *n* > number of receivers to run + (see [Spark Properties](configuration.html#spark-properties) for information on how to set the master). - Extending the logic to running on a cluster, the number of cores allocated to the Spark Streaming - application must be more than the number of receivers. Otherwise the system will receive data, but - not be able to process them. + application must be more than the number of receivers. Otherwise the system will receive data, but + not be able to process it. ### Basic Sources {:.no_toc} @@ -639,7 +639,7 @@ which creates a DStream from text data received over a TCP socket connection. Besides sockets, the StreamingContext API provides methods for creating DStreams from files and Akka actors as input sources. -- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as +- **File Streams:** For reading data from files on any file system compatible with the HDFS API (that is, HDFS, S3, NFS, etc.), a DStream can be created as:
    @@ -682,14 +682,14 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea ### Advanced Sources {:.no_toc} -Python API As of Spark 1.3, -out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. +Python API As of Spark {{site.SPARK_VERSION_SHORT}}, +out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts -of dependencies, the functionality to create DStreams from these sources have been moved to separate -libraries, that can be [linked](#linking) to explicitly when necessary. For example, if you want to -create a DStream using data from Twitter's stream of tweets, you have to do the following. +of dependencies, the functionality to create DStreams from these sources has been moved to separate +libraries that can be [linked](#linking) to explicitly when necessary. For example, if you want to +create a DStream using data from Twitter's stream of tweets, you have to do the following: 1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the SBT/Maven project dependencies. @@ -719,11 +719,11 @@ TwitterUtils.createStream(jssc); Note that these advanced sources are not available in the Spark shell, hence applications based on these advanced sources cannot be tested in the shell. If you really want to use them in the Spark shell you will have to download the corresponding Maven artifact's JAR along with its dependencies -and it in the classpath. +and add it to the classpath. Some of these advanced sources are as follows. -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.1.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. @@ -743,7 +743,7 @@ Some of these advanced sources are as follows. Python API This is not yet supported in Python. -Input DStreams can also be created out of custom data sources. All you have to do is implement an +Input DStreams can also be created out of custom data sources. All you have to do is implement a user-defined **receiver** (see next section to understand what that is) that can receive data from the custom sources and push it into Spark. See the [Custom Receiver Guide](streaming-custom-receivers.html) for details. @@ -753,14 +753,12 @@ Guide](streaming-custom-receivers.html) for details. There can be two kinds of data sources based on their *reliability*. Sources (like Kafka and Flume) allow the transferred data to be acknowledged. If the system receiving -data from these *reliable* sources acknowledge the received data correctly, it can be ensured -that no data gets lost due to any kind of failure. This leads to two kinds of receivers. +data from these *reliable* sources acknowledges the received data correctly, it can be ensured +that no data will be lost due to any kind of failure. This leads to two kinds of receivers: -1. *Reliable Receiver* - A *reliable receiver* correctly acknowledges a reliable - source that the data has been received and stored in Spark with replication. -1. *Unreliable Receiver* - These are receivers for sources that do not support acknowledging. Even - for reliable sources, one may implement an unreliable receiver that do not go into the complexity - of acknowledging correctly. +1. *Reliable Receiver* - A *reliable receiver* correctly sends acknowledgment to a reliable + source when the data has been received and stored in Spark with replication. +1. *Unreliable Receiver* - An *unreliable receiver* does *not* send acknowledgment to a source. This can be used for sources that do not support acknowledgment, or even for reliable sources when one does not want or need to go into the complexity of acknowledgment. The details of how to write a reliable receiver are discussed in the [Custom Receiver Guide](streaming-custom-receivers.html). @@ -828,7 +826,7 @@ Some of the common ones are as follows. cogroup(otherStream, [numTasks]) - When called on DStream of (K, V) and (K, W) pairs, return a new DStream of + When called on a DStream of (K, V) and (K, W) pairs, return a new DStream of (K, Seq[V], Seq[W]) tuples. @@ -852,13 +850,13 @@ A few of these transformations are worth discussing in more detail. The `updateStateByKey` operation allows you to maintain arbitrary state while continuously updating it with new information. To use this, you will have to do two steps. -1. Define the state - The state can be of arbitrary data type. +1. Define the state - The state can be an arbitrary data type. 1. Define the state update function - Specify with a function how to update the state using the -previous state and the new values from input stream. +previous state and the new values from an input stream. Let's illustrate this with an example. Say you want to maintain a running count of each word seen in a text data stream. Here, the running count is the state and it is an integer. We -define the update function as +define the update function as:
    @@ -947,7 +945,7 @@ operation that is not exposed in the DStream API. For example, the functionality of joining every batch in a data stream with another dataset is not directly exposed in the DStream API. However, you can easily use `transform` to do this. This enables very powerful possibilities. For example, -if you want to do real-time data cleaning by joining the input data stream with precomputed +one can do real-time data cleaning by joining the input data stream with precomputed spam information (maybe generated with Spark as well) and then filtering based on it.
    @@ -991,13 +989,14 @@ cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(.
    -In fact, you can also use [machine learning](mllib-guide.html) and -[graph computation](graphx-programming-guide.html) algorithms in the `transform` method. +Note that the supplied function gets called in every batch interval. This allows you to do +time-varying RDD operations, that is, RDD operations, number of partitions, broadcast variables, +etc. can be changed between batches. #### Window Operations {:.no_toc} Spark Streaming also provides *windowed computations*, which allow you to apply -transformations over a sliding window of data. This following figure illustrates this sliding +transformations over a sliding window of data. The following figure illustrates this sliding window.

    @@ -1009,11 +1008,11 @@ window. As shown in the figure, every time the window *slides* over a source DStream, the source RDDs that fall within the window are combined and operated upon to produce the -RDDs of the windowed DStream. In this specific case, the operation is applied over last 3 time +RDDs of the windowed DStream. In this specific case, the operation is applied over the last 3 time units of data, and slides by 2 time units. This shows that any window operation needs to specify two parameters. - * window length - The duration of the window (3 in the figure) + * window length - The duration of the window (3 in the figure). * sliding interval - The interval at which the window operation is performed (2 in the figure). @@ -1021,7 +1020,7 @@ These two parameters must be multiples of the batch interval of the source DStre figure). Let's illustrate the window operations with an example. Say, you want to extend the -[earlier example](#a-quick-example) by generating word counts over last 30 seconds of data, +[earlier example](#a-quick-example) by generating word counts over the last 30 seconds of data, every 10 seconds. To do this, we have to apply the `reduceByKey` operation on the `pairs` DStream of `(word, 1)` pairs over the last 30 seconds of data. This is done using the operation `reduceByKeyAndWindow`. @@ -1096,13 +1095,13 @@ said two parameters - windowLength and slideInterval. reduceByKeyAndWindow(func, invFunc, windowLength, slideInterval, [numTasks]) - A more efficient version of the above reduceByKeyAndWindow() where the reduce + A more efficient version of the above reduceByKeyAndWindow() where the reduce value of each window is calculated incrementally using the reduce values of the previous window. - This is done by reducing the new data that enter the sliding window, and "inverse reducing" the - old data that leave the window. An example would be that of "adding" and "subtracting" counts - of keys as the window slides. However, it is applicable to only "invertible reduce functions", + This is done by reducing the new data that enters the sliding window, and "inverse reducing" the + old data that leaves the window. An example would be that of "adding" and "subtracting" counts + of keys as the window slides. However, it is applicable only to "invertible reduce functions", that is, those reduce functions which have a corresponding "inverse reduce" function (taken as - parameter invFunc. Like in reduceByKeyAndWindow, the number of reduce tasks + parameter invFunc). Like in reduceByKeyAndWindow, the number of reduce tasks is configurable through an optional argument. Note that [checkpointing](#checkpointing) must be enabled for using this operation. @@ -1224,7 +1223,7 @@ For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.stre *** ## Output Operations on DStreams -Output operations allow DStream's data to be pushed out external systems like a database or a file systems. +Output operations allow DStream's data to be pushed out to external systems like a database or a file systems. Since the output operations actually allow the transformed data to be consumed by external systems, they trigger the actual execution of all the DStream transformations (similar to actions for RDDs). Currently, the following output operations are defined: @@ -1233,7 +1232,7 @@ Currently, the following output operations are defined: Output OperationMeaning print() - Prints first ten elements of every batch of data in a DStream on the driver node running + Prints the first ten elements of every batch of data in a DStream on the driver node running the streaming application. This is useful for development and debugging.
    Python API This is called @@ -1242,12 +1241,12 @@ Currently, the following output operations are defined: saveAsTextFiles(prefix, [suffix]) - Save this DStream's contents as a text files. The file name at each batch interval is + Save this DStream's contents as text files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]". saveAsObjectFiles(prefix, [suffix]) - Save this DStream's contents as a SequenceFile of serialized Java objects. The file + Save this DStream's contents as SequenceFiles of serialized Java objects. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    @@ -1257,7 +1256,7 @@ Currently, the following output operations are defined: saveAsHadoopFiles(prefix, [suffix]) - Save this DStream's contents as a Hadoop file. The file name at each batch interval is + Save this DStream's contents as Hadoop files. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    Python API This is not available in @@ -1267,7 +1266,7 @@ Currently, the following output operations are defined: foreachRDD(func) The most generic output operator that applies a function, func, to each RDD generated from - the stream. This function should push the data in each RDD to a external system, like saving the RDD to + the stream. This function should push the data in each RDD to an external system, such as saving the RDD to files, or writing it over the network to a database. Note that the function func is executed in the driver process running the streaming application, and will usually have RDD actions in it that will force the computation of the streaming RDDs. @@ -1277,14 +1276,14 @@ Currently, the following output operations are defined: ### Design Patterns for using foreachRDD {:.no_toc} -`dstream.foreachRDD` is a powerful primitive that allows data to sent out to external systems. +`dstream.foreachRDD` is a powerful primitive that allows data to be sent out to external systems. However, it is important to understand how to use this primitive correctly and efficiently. Some of the common mistakes to avoid are as follows. Often writing data to external system requires creating a connection object (e.g. TCP connection to a remote server) and using it to send data to a remote system. For this purpose, a developer may inadvertently try creating a connection object at -the Spark driver, but try to use it in a Spark worker to save records in the RDDs. +the Spark driver, and then try to use it in a Spark worker to save records in the RDDs. For example (in Scala),

    @@ -1346,7 +1345,7 @@ dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use -`rdd.foreachPartition` - create a single connection object and send all the records in a RDD +`rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
    @@ -1427,26 +1426,6 @@ You can easily use [DataFrames and SQL](sql-programming-guide.html) operations o
    {% highlight scala %} -/** Lazily instantiated singleton instance of SQLContext */ -object SQLContextSingleton { - @transient private var instance: SQLContext = null - - // Instantiate SQLContext on demand - def getInstance(sparkContext: SparkContext): SQLContext = synchronized { - if (instance == null) { - instance = new SQLContext(sparkContext) - } - instance - } -} - -... - -/** Case class for converting RDD to DataFrame */ -case class Row(word: String) - -... - /** DataFrame operations inside your streaming program */ val words: DStream[String] = ... @@ -1454,11 +1433,11 @@ val words: DStream[String] = ... words.foreachRDD { rdd => // Get the singleton instance of SQLContext - val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) + val sqlContext = SQLContext.getOrCreate(rdd.sparkContext) import sqlContext.implicits._ - // Convert RDD[String] to RDD[case class] to DataFrame - val wordsDataFrame = rdd.map(w => Row(w)).toDF() + // Convert RDD[String] to DataFrame + val wordsDataFrame = rdd.toDF("word") // Register as table wordsDataFrame.registerTempTable("words") @@ -1476,19 +1455,6 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
    {% highlight java %} -/** Lazily instantiated singleton instance of SQLContext */ -class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { - if (instance == null) { - instance = new SQLContext(sparkContext); - } - return instance; - } -} - -... - /** Java Bean class for converting RDD to DataFrame */ public class JavaRow implements java.io.Serializable { private String word; @@ -1512,7 +1478,9 @@ words.foreachRDD( new Function2, Time, Void>() { @Override public Void call(JavaRDD rdd, Time time) { - SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + + // Get the singleton instance of SQLContext + SQLContext sqlContext = SQLContext.getOrCreate(rdd.context()); // Convert RDD[String] to RDD[case class] to DataFrame JavaRDD rowRDD = rdd.map(new Function() { @@ -1581,7 +1549,7 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
    -You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember sufficient amount of streaming data such that query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). +You can also run SQL queries on tables defined on streaming data from a different thread (that is, asynchronous to the running StreamingContext). Just make sure that you set the StreamingContext to remember a sufficient amount of streaming data such that the query can run. Otherwise the StreamingContext, which is unaware of the any asynchronous SQL queries, will delete off old streaming data before the query can complete. For example, if you want to query the last batch, but your query can take 5 minutes to run, then call `streamingContext.remember(Minutes(5))` (in Scala, or equivalent in other languages). See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more about DataFrames. @@ -1594,7 +1562,7 @@ You can also easily use machine learning algorithms provided by [MLlib](mllib-gu ## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, -using `persist()` method on a DStream will automatically persist every RDD of that DStream in +using the `persist()` method on a DStream will automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. @@ -1606,28 +1574,27 @@ default persistence level is set to replicate the data to two nodes for fault-to Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More -information on different persistence levels can be found in -[Spark Programming Guide](programming-guide.html#rdd-persistence). +information on different persistence levels can be found in the [Spark Programming Guide](programming-guide.html#rdd-persistence). *** ## Checkpointing A streaming application must operate 24/7 and hence must be resilient to failures unrelated to the application logic (e.g., system failures, JVM crashes, etc.). For this to be possible, -Spark Streaming needs to *checkpoints* enough information to a fault- +Spark Streaming needs to *checkpoint* enough information to a fault- tolerant storage system such that it can recover from failures. There are two types of data that are checkpointed. - *Metadata checkpointing* - Saving of the information defining the streaming computation to fault-tolerant storage like HDFS. This is used to recover from failure of the node running the driver of the streaming application (discussed in detail later). Metadata includes: - + *Configuration* - The configuration that were used to create the streaming application. + + *Configuration* - The configuration that was used to create the streaming application. + *DStream operations* - The set of DStream operations that define the streaming application. + *Incomplete batches* - Batches whose jobs are queued but have not completed yet. - *Data checkpointing* - Saving of the generated RDDs to reliable storage. This is necessary in some *stateful* transformations that combine data across multiple batches. In such - transformations, the generated RDDs depends on RDDs of previous batches, which causes the length - of the dependency chain to keep increasing with time. To avoid such unbounded increase in recovery + transformations, the generated RDDs depend on RDDs of previous batches, which causes the length + of the dependency chain to keep increasing with time. To avoid such unbounded increases in recovery time (proportional to dependency chain), intermediate RDDs of stateful transformations are periodically *checkpointed* to reliable storage (e.g. HDFS) to cut off the dependency chains. @@ -1641,10 +1608,10 @@ transformations are used. Checkpointing must be enabled for applications with any of the following requirements: - *Usage of stateful transformations* - If either `updateStateByKey` or `reduceByKeyAndWindow` (with - inverse function) is used in the application, then the checkpoint directory must be provided for - allowing periodic RDD checkpointing. + inverse function) is used in the application, then the checkpoint directory must be provided to + allow for periodic RDD checkpointing. - *Recovering from failures of the driver running the application* - Metadata checkpoints are used - for to recover with progress information. + to recover with progress information. Note that simple streaming applications without the aforementioned stateful transformations can be run without enabling checkpointing. The recovery from driver failures will also be partial in @@ -1659,7 +1626,7 @@ Checkpointing can be enabled by setting a directory in a fault-tolerant, reliable file system (e.g., HDFS, S3, etc.) to which the checkpoint information will be saved. This is done by using `streamingContext.checkpoint(checkpointDirectory)`. This will allow you to use the aforementioned stateful transformations. Additionally, -if you want make the application recover from driver failures, you should rewrite your +if you want to make the application recover from driver failures, you should rewrite your streaming application to have the following behavior. + When the program is being started for the first time, it will create a new StreamingContext, @@ -1780,18 +1747,17 @@ You can also explicitly create a `StreamingContext` from the checkpoint data and In addition to using `getOrCreate` one also needs to ensure that the driver process gets restarted automatically on failure. This can only be done by the deployment infrastructure that is used to run the application. This is further discussed in the -[Deployment](#deploying-applications.html) section. +[Deployment](#deploying-applications) section. Note that checkpointing of RDDs incurs the cost of saving to reliable storage. This may cause an increase in the processing time of those batches where RDDs get checkpointed. Hence, the interval of checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every batch may significantly reduce operation throughput. Conversely, checkpointing too infrequently -causes the lineage and task sizes to grow which may have detrimental effects. For stateful +causes the lineage and task sizes to grow, which may have detrimental effects. For stateful transformations that require RDD checkpointing, the default interval is a multiple of the batch interval that is at least 10 seconds. It can be set by using -`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 times of -sliding interval of a DStream is good setting to try. +`dstream.checkpoint(checkpointInterval)`. Typically, a checkpoint interval of 5 - 10 sliding intervals of a DStream is a good setting to try. *** @@ -1864,17 +1830,17 @@ To run a Spark Streaming applications, you need to have the following. {:.no_toc} If a running Spark Streaming application needs to be upgraded with new -application code, then there are two possible mechanism. +application code, then there are two possible mechanisms. - The upgraded Spark Streaming application is started and run in parallel to the existing application. -Once the new one (receiving the same data as the old one) has been warmed up and ready +Once the new one (receiving the same data as the old one) has been warmed up and is ready for prime time, the old one be can be brought down. Note that this can be done for data sources that support sending the data to two destinations (i.e., the earlier and upgraded applications). - The existing application is shutdown gracefully (see [`StreamingContext.stop(...)`](api/scala/index.html#org.apache.spark.streaming.StreamingContext) or [`JavaStreamingContext.stop(...)`](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) -for graceful shutdown options) which ensure data that have been received is completely +for graceful shutdown options) which ensure data that has been received is completely processed before shutdown. Then the upgraded application can be started, which will start processing from the same point where the earlier application left off. Note that this can be done only with input sources that support source-side buffering @@ -1909,10 +1875,10 @@ The following two metrics in web UI are particularly important: to finish. If the batch processing time is consistently more than the batch interval and/or the queueing -delay keeps increasing, then it indicates the system is -not able to process the batches as fast they are being generated and falling behind. +delay keeps increasing, then it indicates that the system is +not able to process the batches as fast they are being generated and is falling behind. In that case, consider -[reducing](#reducing-the-processing-time-of-each-batch) the batch processing time. +[reducing](#reducing-the-batch-processing-times) the batch processing time. The progress of a Spark Streaming program can also be monitored using the [StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface, @@ -1923,8 +1889,8 @@ and it is likely to be improved upon (i.e., more information reported) in the fu *************************************************************************************************** # Performance Tuning -Getting the best performance of a Spark Streaming application on a cluster requires a bit of -tuning. This section explains a number of the parameters and configurations that can tuned to +Getting the best performance out of a Spark Streaming application on a cluster requires a bit of +tuning. This section explains a number of the parameters and configurations that can be tuned to improve the performance of you application. At a high level, you need to consider two things: 1. Reducing the processing time of each batch of data by efficiently using cluster resources. @@ -1934,22 +1900,22 @@ improve the performance of you application. At a high level, you need to conside ## Reducing the Batch Processing Times There are a number of optimizations that can be done in Spark to minimize the processing time of -each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section +each batch. These have been discussed in detail in the [Tuning Guide](tuning.html). This section highlights some of the most important ones. ### Level of Parallelism in Data Receiving {:.no_toc} -Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to deserialized +Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to be deserialized and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider parallelizing the data receiving. Note that each input DStream creates a single receiver (running on a worker machine) that receives a single stream of data. Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two -Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. These multiple -DStream can be unioned together to create a single DStream. Then the transformations that was -being applied on the single input DStream can applied on the unified stream. This is done as follows. +Kafka input streams, each receiving only one topic. This would run two receivers, +allowing data to be received in parallel, thus increasing overall throughput. These multiple +DStreams can be unioned together to create a single DStream. Then the transformations that were +being applied on a single input DStream can be applied on the unified stream. This is done as follows.
    @@ -1971,16 +1937,24 @@ JavaPairDStream unifiedStream = streamingContext.union(kafkaStre unifiedStream.print(); {% endhighlight %}
    +
    +{% highlight python %} +numStreams = 5 +kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] +unifiedStream = streamingContext.union(kafkaStreams) +unifiedStream.print() +{% endhighlight %} +
    Another parameter that should be considered is the receiver's blocking interval, which is determined by the [configuration parameter](configuration.html#spark-streaming) `spark.streaming.blockInterval`. For most receivers, the received data is coalesced together into blocks of data before storing inside Spark's memory. The number of blocks in each batch -determines the number of tasks that will be used to process those +determines the number of tasks that will be used to process the received data in a map-like transformation. The number of tasks per receiver per batch will be approximately (batch interval / block interval). For example, block interval of 200 ms will -create 10 tasks per 2 second batches. Too low the number of tasks (that is, less than the number +create 10 tasks per 2 second batches. If the number of tasks is too low (that is, less than the number of cores per machine), then it will be inefficient as all available cores will not be used to process the data. To increase the number of tasks for a given batch interval, reduce the block interval. However, the recommended minimum value of block interval is about 50 ms, @@ -1988,7 +1962,7 @@ below which the task launching overheads may be a problem. An alternative to receiving data with multiple input streams / receivers is to explicitly repartition the input data stream (using `inputStream.repartition()`). -This distributes the received batches of data across specified number of machines in the cluster +This distributes the received batches of data across the specified number of machines in the cluster before further processing. ### Level of Parallelism in Data Processing @@ -1996,7 +1970,7 @@ before further processing. Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is controlled by -the`spark.default.parallelism` [configuration property](configuration.html#spark-properties). You +the `spark.default.parallelism` [configuration property](configuration.html#spark-properties). You can pass the level of parallelism as an argument (see [`PairDStreamFunctions`](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions) documentation), or set the `spark.default.parallelism` @@ -2004,20 +1978,20 @@ documentation), or set the `spark.default.parallelism` ### Data Serialization {:.no_toc} -The overheads of data serialization can be reduce by tuning the serialization formats. In case of streaming, there are two types of data that are being serialized. +The overheads of data serialization can be reduced by tuning the serialization formats. In the case of streaming, there are two types of data that are being serialized. -* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is unsufficient to hold all the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. +* **Input data**: By default, the input data received through Receivers is stored in the executors' memory with [StorageLevel.MEMORY_AND_DISK_SER_2](api/scala/index.html#org.apache.spark.storage.StorageLevel$). That is, the data is serialized into bytes to reduce GC overheads, and replicated for tolerating executor failures. Also, the data is kept first in memory, and spilled over to disk only if the memory is insufficient to hold all of the input data necessary for the streaming computation. This serialization obviously has overheads -- the receiver must deserialize the received data and re-serialize it using Spark's serialization format. -* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operation persist data in memory as they would be processed multiple times. However, unlike Spark, by default RDDs are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) to minimize GC overheads. +* **Persisted RDDs generated by Streaming Operations**: RDDs generated by streaming computations may be persisted in memory. For example, window operations persist data in memory as they would be processed multiple times. However, unlike the Spark Core default of [StorageLevel.MEMORY_ONLY](api/scala/index.html#org.apache.spark.storage.StorageLevel$), persisted RDDs generated by streaming computations are persisted with [StorageLevel.MEMORY_ONLY_SER](api/scala/index.html#org.apache.spark.storage.StorageLevel$) (i.e. serialized) by default to minimize GC overheads. -In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization)) for more details. Consider registering custom classes, and disabling object reference tracking for Kryo (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). +In both cases, using Kryo serialization can reduce both CPU and memory overheads. See the [Spark Tuning Guide](tuning.html#data-serialization) for more details. For Kryo, consider registering custom classes, and disabling object reference tracking (see Kryo-related configurations in the [Configuration Guide](configuration.html#compression-and-serialization)). -In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. +In specific cases where the amount of data that needs to be retained for the streaming application is not large, it may be feasible to persist data (both types) as deserialized objects without incurring excessive GC overheads. For example, if you are using batch intervals of a few seconds and no window operations, then you can try disabling serialization in persisted data by explicitly setting the storage level accordingly. This would reduce the CPU overheads due to serialization, potentially improving performance without too much GC overheads. ### Task Launching Overheads {:.no_toc} If the number of tasks launched per second is high (say, 50 or more per second), then the overhead -of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second +of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: * **Task Serialization**: Using Kryo serialization for serializing tasks can reduce the task @@ -2036,7 +2010,7 @@ thus allowing sub-second batch size to be viable. For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed as fast as they are being generated. Whether this is true for an application can be found by -[monitoring](#monitoring) the processing times in the streaming web UI, where the batch +[monitoring](#monitoring-applications) the processing times in the streaming web UI, where the batch processing time should be less than the batch interval. Depending on the nature of the streaming @@ -2049,35 +2023,35 @@ production can be sustained. A good approach to figure out the right batch size for your application is to test it with a conservative batch interval (say, 5-10 seconds) and a low data rate. To verify whether the system -is able to keep up with data rate, you can check the value of the end-to-end delay experienced +is able to keep up with the data rate, you can check the value of the end-to-end delay experienced by each processed batch (either look for "Total delay" in Spark driver log4j logs, or use the [StreamingListener](api/scala/index.html#org.apache.spark.streaming.scheduler.StreamingListener) interface). If the delay is maintained to be comparable to the batch size, then system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it therefore unstable. Once you have an idea of a stable configuration, you can try increasing the -data rate and/or reducing the batch size. Note that momentary increase in the delay due to -temporary data rate increases maybe fine as long as the delay reduces back to a low value +data rate and/or reducing the batch size. Note that a momentary increase in the delay due to +temporary data rate increases may be fine as long as the delay reduces back to a low value (i.e., less than batch size). *** ## Memory Tuning -Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail +Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail in the [Tuning Guide](tuning.html#memory-tuning). It is strongly recommended that you read that. In this section, we discuss a few tuning parameters specifically in the context of Spark Streaming applications. -The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes of worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then necessary memory will be low. +The amount of cluster memory required by a Spark Streaming application depends heavily on the type of transformations used. For example, if you want to use a window operation on the last 10 minutes of data, then your cluster should have sufficient memory to hold 10 minutes worth of data in memory. Or if you want to use `updateStateByKey` with a large number of keys, then the necessary memory will be high. On the contrary, if you want to do a simple map-filter-store operation, then the necessary memory will be low. -In general, since the data received through receivers are stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. +In general, since the data received through receivers is stored with StorageLevel.MEMORY_AND_DISK_SER_2, the data that does not fit in memory will spill over to the disk. This may reduce the performance of the streaming application, and hence it is advised to provide sufficient memory as required by your streaming application. Its best to try and see the memory usage on a small scale and estimate accordingly. -Another aspect of memory tuning is garbage collection. For a streaming application that require low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. +Another aspect of memory tuning is garbage collection. For a streaming application that requires low latency, it is undesirable to have large pauses caused by JVM Garbage Collection. -There are a few parameters that can help you tune the memory usage and GC overheads. +There are a few parameters that can help you tune the memory usage and GC overheads: -* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both, the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. +* **Persistence Level of DStreams**: As mentioned earlier in the [Data Serialization](#data-serialization) section, the input data and RDDs are by default persisted as serialized bytes. This reduces both the memory usage and GC overheads, compared to deserialized persistence. Enabling Kryo serialization further reduces serialized sizes and memory usage. Further reduction in memory usage can be achieved with compression (see the Spark configuration `spark.rdd.compress`), at the cost of CPU time. -* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using window operation of 10 minutes, then Spark Streaming will keep around last 10 minutes of data, and actively throw away older data. -Data can be retained for longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. +* **Clearing old data**: By default, all input data and persisted RDDs generated by DStream transformations are automatically cleared. Spark Streaming decides when to clear the data based on the transformations that are used. For example, if you are using a window operation of 10 minutes, then Spark Streaming will keep around the last 10 minutes of data, and actively throw away older data. +Data can be retained for a longer duration (e.g. interactively querying older data) by setting `streamingContext.remember`. * **CMS Garbage Collector**: Use of the concurrent mark-and-sweep GC is strongly recommended for keeping GC-related pauses consistently low. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more @@ -2107,18 +2081,18 @@ re-computed from the original fault-tolerant dataset using the lineage of operat 1. Assuming that all of the RDD transformations are deterministic, the data in the final transformed RDD will always be the same irrespective of failures in the Spark cluster. -Spark operates on data on fault-tolerant file systems like HDFS or S3. Hence, +Spark operates on data in fault-tolerant file systems like HDFS or S3. Hence, all of the RDDs generated from the fault-tolerant data are also fault-tolerant. However, this is not the case for Spark Streaming as the data in most cases is received over the network (except when `fileStream` is used). To achieve the same fault-tolerance properties for all of the generated RDDs, the received data is replicated among multiple Spark executors in worker nodes in the cluster (default replication factor is 2). This leads to two kinds of data in the -system that needs to recovered in the event of failures: +system that need to recovered in the event of failures: 1. *Data received and replicated* - This data survives failure of a single worker node as a copy - of it exists on one of the nodes. + of it exists on one of the other nodes. 1. *Data received but buffered for replication* - Since this is not replicated, - the only way to recover that data is to get it again from the source. + the only way to recover this data is to get it again from the source. Furthermore, there are two kinds of failures that we should be concerned about: @@ -2145,13 +2119,13 @@ In any stream processing system, broadly speaking, there are three steps in proc 1. *Receiving the data*: The data is received from sources using Receivers or otherwise. -1. *Transforming the data*: The data received data is transformed using DStream and RDD transformations. +1. *Transforming the data*: The received data is transformed using DStream and RDD transformations. 1. *Pushing out the data*: The final transformed data is pushed out to external systems like file systems, databases, dashboards, etc. -If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. +If a streaming application has to achieve end-to-end exactly-once guarantees, then each step has to provide an exactly-once guarantee. That is, each record must be received exactly once, transformed exactly once, and pushed to downstream systems exactly once. Let's understand the semantics of these steps in the context of Spark Streaming. -1. *Receiving the data*: Different input sources provided different guarantees. This is discussed in detail in the next subsection. +1. *Receiving the data*: Different input sources provide different guarantees. This is discussed in detail in the next subsection. 1. *Transforming the data*: All data that has been received will be processed _exactly once_, thanks to the guarantees that RDDs provide. Even if there are failures, as long as the received input data is accessible, the final transformed RDDs will always have the same contents. @@ -2163,9 +2137,9 @@ Different input sources provide different guarantees, ranging from _at-least onc ### With Files {:.no_toc} -If all of the input data is already present in a fault-tolerant files system like -HDFS, Spark Streaming can always recover from any failure and process all the data. This gives -*exactly-once* semantics, that all the data will be processed exactly once no matter what fails. +If all of the input data is already present in a fault-tolerant file system like +HDFS, Spark Streaming can always recover from any failure and process all of the data. This gives +*exactly-once* semantics, meaning all of the data will be processed exactly once no matter what fails. ### With Receiver-based Sources {:.no_toc} @@ -2174,21 +2148,21 @@ scenario and the type of receiver. As we discussed [earlier](#receiver-reliability), there are two types of receivers: 1. *Reliable Receiver* - These receivers acknowledge reliable sources only after ensuring that - the received data has been replicated. If such a receiver fails, - the buffered (unreplicated) data does not get acknowledged to the source. If the receiver is - restarted, the source will resend the data, and therefore no data will be lost due to the failure. -1. *Unreliable Receiver* - Such receivers can lose data when they fail due to worker - or driver failures. + the received data has been replicated. If such a receiver fails, the source will not receive + acknowledgment for the buffered (unreplicated) data. Therefore, if the receiver is + restarted, the source will resend the data, and no data will be lost due to the failure. +1. *Unreliable Receiver* - Such receivers do *not* send acknowledgment and therefore *can* lose + data when they fail due to worker or driver failures. Depending on what type of receivers are used we achieve the following semantics. If a worker node fails, then there is no data loss with reliable receivers. With unreliable receivers, data received but not replicated can get lost. If the driver node fails, -then besides these losses, all the past data that was received and replicated in memory will be +then besides these losses, all of the past data that was received and replicated in memory will be lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which saves the received data to fault-tolerant storage. With the [write ahead logs -enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides at-least once guarantee. +ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2234,7 +2208,7 @@ The following table summarizes the semantics under failures: ### With Kafka Direct API {:.no_toc} -In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark 1.3) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark {{site.SPARK_VERSION_SHORT}}) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). ## Semantics of output operations {:.no_toc} @@ -2248,9 +2222,16 @@ additional effort may be necessary to achieve exactly-once semantics. There are - *Transactional updates*: All updates are made transactionally so that updates are made exactly once atomically. One way to do this would be the following. - - Use the batch time (available in `foreachRDD`) and the partition index of the transformed RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. - - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else if this was already committed, skip the update. + - Use the batch time (available in `foreachRDD`) and the partition index of the RDD to create an identifier. This identifier uniquely identifies a blob data in the streaming application. + - Update external system with this blob transactionally (that is, exactly once, atomically) using the identifier. That is, if the identifier is not already committed, commit the partition data and the identifier atomically. Else, if this was already committed, skip the update. + dstream.foreachRDD { (rdd, time) => + rdd.foreachPartition { partitionIterator => + val partitionId = TaskContext.get.partitionId() + val uniqueId = generateUniqueId(time.milliseconds, partitionId) + // use this uniqueId to transactionally commit the data in partitionIterator + } + } *************************************************************************************************** *************************************************************************************************** @@ -2325,7 +2306,7 @@ package and renamed for better clarity. - Java docs * [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html), [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and - [PairJavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/PairJavaDStream.html) + [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index ab4a96f232c1..18ccbc0a3edd 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,8 +19,9 @@ # limitations under the License. # -from __future__ import with_statement, print_function +from __future__ import division, print_function, with_statement +import codecs import hashlib import itertools import logging @@ -47,8 +48,10 @@ else: from urllib.request import urlopen, Request from urllib.error import HTTPError + raw_input = input + xrange = range -SPARK_EC2_VERSION = "1.2.1" +SPARK_EC2_VERSION = "1.4.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -65,6 +68,9 @@ "1.1.1", "1.2.0", "1.2.1", + "1.3.0", + "1.3.1", + "1.4.0", ]) SPARK_TACHYON_MAP = { @@ -75,6 +81,9 @@ "1.1.1": "0.5.0", "1.2.0": "0.5.0", "1.2.1": "0.5.0", + "1.3.0": "0.5.0", + "1.3.1": "0.5.0", + "1.4.0": "0.6.4", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION @@ -82,7 +91,7 @@ # Default location to get the spark-ec2 scripts (and ami-list) from DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.3" +DEFAULT_SPARK_EC2_BRANCH = "branch-1.4" def setup_external_libs(libs): @@ -212,7 +221,8 @@ def parse_args(): "(default: %default).") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: %default)") + help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + + "(Hadoop 2.4.0) (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -264,7 +274,8 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + + "is used as Hadoop major version (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + @@ -278,6 +289,10 @@ def parse_args(): parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--additional-tags", type="string", default="", + help="Additional tags to set on the machines; tags are comma-separated, while name and " + + "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") parser.add_option( "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") @@ -291,6 +306,13 @@ def parse_args(): "--private-ips", action="store_true", default=False, help="Use private IPs for instances rather than public if VPC/subnet " + "requires that.") + parser.add_option( + "--instance-initiated-shutdown-behavior", default="stop", + choices=["stop", "terminate"], + help="Whether instances should terminate when shut down or just stop") + parser.add_option( + "--instance-profile-name", default=None, + help="IAM profile name to launch instances under") (opts, args) = parser.parse_args() if len(args) != 2: @@ -347,7 +369,7 @@ def get_validate_spark_version(version, repo): # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-05-08 +# Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. EC2_INSTANCE_TYPES = { "c1.medium": "pvm", @@ -389,6 +411,11 @@ def get_validate_spark_version(version, repo): "m3.large": "hvm", "m3.xlarge": "hvm", "m3.2xlarge": "hvm", + "m4.large": "hvm", + "m4.xlarge": "hvm", + "m4.2xlarge": "hvm", + "m4.4xlarge": "hvm", + "m4.10xlarge": "hvm", "r3.large": "hvm", "r3.xlarge": "hvm", "r3.2xlarge": "hvm", @@ -398,6 +425,7 @@ def get_validate_spark_version(version, repo): "t2.micro": "hvm", "t2.small": "hvm", "t2.medium": "hvm", + "t2.large": "hvm", } @@ -419,13 +447,14 @@ def get_spark_ami(opts): b=opts.spark_ec2_git_branch) ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) + reader = codecs.getreader("ascii") try: - ami = urlopen(ami_path).read().strip() - print("Spark AMI: " + ami) + ami = reader(urlopen(ami_path)).read().strip() except: print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) + print("Spark AMI: " + ami) return ami @@ -476,6 +505,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # Rstudio (GUI for R) needs port 8787 for web access + master_group.authorize('tcp', 8787, 8787, authorized_address) # HDFS NFS gateway requires 111,2049,4242 for tcp & udp master_group.authorize('tcp', 111, 111, authorized_address) master_group.authorize('udp', 111, 111, authorized_address) @@ -483,6 +514,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('udp', 2049, 2049, authorized_address) master_group.authorize('tcp', 4242, 4242, authorized_address) master_group.authorize('udp', 4242, 4242, authorized_address) + # RM in YARN mode uses 8088 + master_group.authorize('tcp', 8088, 8088, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -578,7 +611,8 @@ def launch_cluster(conn, opts, cluster_name): block_device_map=block_map, subnet_id=opts.subnet_id, placement_group=opts.placement_group, - user_data=user_data_content) + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -623,16 +657,19 @@ def launch_cluster(conn, opts, cluster_name): for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + slave_res = image.run( + key_name=opts.key_pair, + security_group_ids=[slave_group.id] + additional_group_ids, + instance_type=opts.instance_type, + placement=zone, + min_count=num_slaves_this_zone, + max_count=num_slaves_this_zone, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) slave_nodes += slave_res.instances print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( s=num_slaves_this_zone, @@ -654,32 +691,43 @@ def launch_cluster(conn, opts, cluster_name): master_type = opts.instance_type if opts.zone == 'all': opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + master_res = image.run( + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + placement=opts.zone, + min_count=1, + max_count=1, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) master_nodes = master_res.instances print("Launched master in %s, regid = %s" % (zone, master_res.id)) # This wait time corresponds to SPARK-4983 print("Waiting for AWS to propagate instance metadata...") - time.sleep(5) - # Give the instances descriptive names + time.sleep(15) + + # Give the instances descriptive names and set additional tags + additional_tags = {} + if opts.additional_tags.strip(): + additional_tags = dict( + map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') + ) + for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + master.add_tags( + dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + ) + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + slave.add_tags( + dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + ) # Return all the instances return (master_nodes, slave_nodes) @@ -746,11 +794,15 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): 'mapreduce', 'spark-standalone', 'tachyon'] if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) + modules = list(filter(lambda x: x != "mapreduce", modules)) if opts.ganglia: modules.append('ganglia') + # Clear SPARK_WORKER_INSTANCES if running on YARN + if opts.hadoop_major_version == "yarn": + opts.worker_instances = "" + # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( @@ -860,7 +912,11 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): for i in cluster_instances: i.update() - statuses = conn.get_all_instance_status(instance_ids=[i.id for i in cluster_instances]) + max_batch = 100 + statuses = [] + for j in xrange(0, len(cluster_instances), max_batch): + batch = [i.id for i in cluster_instances[j:j + max_batch]] + statuses.extend(conn.get_all_instance_status(instance_ids=batch)) if cluster_state == 'ssh-ready': if all(i.state == 'running' for i in cluster_instances) and \ @@ -889,7 +945,7 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-05-08 + # Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, @@ -931,6 +987,11 @@ def get_num_disks(instance_type): "m3.large": 1, "m3.xlarge": 2, "m3.2xlarge": 2, + "m4.large": 0, + "m4.xlarge": 0, + "m4.2xlarge": 0, + "m4.4xlarge": 0, + "m4.10xlarge": 0, "r3.large": 1, "r3.xlarge": 1, "r3.2xlarge": 1, @@ -940,6 +1001,7 @@ def get_num_disks(instance_type): "t2.micro": 0, "t2.small": 0, "t2.medium": 0, + "t2.large": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -984,6 +1046,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] + worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" template_vars = { "master_list": '\n'.join(master_addresses), "active_master": active_master, @@ -997,7 +1060,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "spark_version": spark_v, "tachyon_version": tachyon_v, "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": "%d" % opts.worker_instances, + "spark_worker_instances": worker_instances_str, "spark_master_opts": opts.master_opts } @@ -1152,7 +1215,7 @@ def get_zones(conn, opts): # Gets the number of items in a partition def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions + num_slaves_this_zone = total // num_partitions if (total % num_partitions) - current_partitions > 0: num_slaves_this_zone += 1 return num_slaves_this_zone diff --git a/examples/pom.xml b/examples/pom.xml index 5b04b4f8d6ca..e6884b09dca9 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -97,6 +97,11 @@ + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + org.apache.hbase hbase-testing-util @@ -392,45 +397,6 @@ - - - scala-2.10 - - !scala-2.11 - - - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-scala-sources - generate-sources - - add-source - - - - src/main/scala - scala-2.10/src/main/scala - scala-2.10/src/main/java - - - - - - - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index ec533d174ebd..9df26ffca577 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -156,6 +156,11 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Create a model, and return it. return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); } + + @Override + public MyJavaLogisticRegression copy(ParamMap extra) { + return defaultCopy(extra); + } } /** diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 29158d5c8565..dac649d1d5ae 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -97,7 +97,7 @@ public static void main(String[] args) { DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. DataFrame results = model2.transform(test); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 8159ffbe2d26..afee279ec32b 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -94,12 +94,12 @@ public String call(Row row) { System.out.println("=== Data source: Parquet File ==="); // DataFrames can be saved as parquet files, maintaining the schema information. - schemaPeople.saveAsParquetFile("people.parquet"); + schemaPeople.write().parquet("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.read().json(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 5b82a14fba41..c5ae5d043b8e 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -18,6 +18,7 @@ from __future__ import print_function import sys +import json from pyspark import SparkContext @@ -27,24 +28,24 @@ hbase(main):016:0> create 'test', 'f1' 0 row(s) in 1.0430 seconds -hbase(main):017:0> put 'test', 'row1', 'f1', 'value1' +hbase(main):017:0> put 'test', 'row1', 'f1:a', 'value1' 0 row(s) in 0.0130 seconds -hbase(main):018:0> put 'test', 'row2', 'f1', 'value2' +hbase(main):018:0> put 'test', 'row1', 'f1:b', 'value2' 0 row(s) in 0.0030 seconds -hbase(main):019:0> put 'test', 'row3', 'f1', 'value3' +hbase(main):019:0> put 'test', 'row2', 'f1', 'value3' 0 row(s) in 0.0050 seconds -hbase(main):020:0> put 'test', 'row4', 'f1', 'value4' +hbase(main):020:0> put 'test', 'row3', 'f1', 'value4' 0 row(s) in 0.0110 seconds hbase(main):021:0> scan 'test' ROW COLUMN+CELL - row1 column=f1:, timestamp=1401883411986, value=value1 - row2 column=f1:, timestamp=1401883415212, value=value2 - row3 column=f1:, timestamp=1401883417858, value=value3 - row4 column=f1:, timestamp=1401883420805, value=value4 + row1 column=f1:a, timestamp=1401883411986, value=value1 + row1 column=f1:b, timestamp=1401883415212, value=value2 + row2 column=f1:, timestamp=1401883417858, value=value3 + row3 column=f1:, timestamp=1401883420805, value=value4 4 row(s) in 0.0240 seconds """ if __name__ == "__main__": @@ -64,6 +65,8 @@ table = sys.argv[2] sc = SparkContext(appName="HBaseInputFormat") + # Other options for configuring scan behavior are available. More information available at + # https://github.com/apache/hbase/blob/master/hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormat.java conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} if len(sys.argv) > 3: conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], @@ -78,6 +81,8 @@ keyConverter=keyConv, valueConverter=valueConv, conf=conf) + hbase_rdd = hbase_rdd.flatMapValues(lambda v: v.split("\n")).mapValues(json.loads) + output = hbase_rdd.collect() for (k, v) in output: print((k, v)) diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 1456c8731284..0ea7cfb7025a 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -68,7 +68,7 @@ def closestPoint(p, centers): closest = data.map( lambda p: (closestPoint(p, kPoints), (p, 1))) pointStats = closest.reduceByKey( - lambda (p1, c1), (p2, c2): (p1 + p2, c1 + c2)) + lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1])) newPoints = pointStats.map( lambda st: (st[0], st[1][0] / st[1][1])).collect() diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py new file mode 100644 index 000000000000..f0ca97c72494 --- /dev/null +++ b/examples/src/main/python/ml/cross_validator.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.evaluation import BinaryClassificationEvaluator +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="CrossValidatorExample") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + # This will allow us to jointly choose parameters for all Pipeline stages. + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + paramGrid = ParamGridBuilder() \ + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .build() + + crossval = CrossValidator(estimator=pipeline, + estimatorParamMaps=paramGrid, + evaluator=BinaryClassificationEvaluator(), + numFolds=2) # use 3+ folds in practice + + # Run cross-validation, and choose the best set of parameters. + cvModel = crossval.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents. cvModel uses the best model found (lrModel). + prediction = cvModel.transform(test) + selected = prediction.select("id", "text", "probability", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py new file mode 100644 index 000000000000..6446f0fe5eea --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import GBTRegressor +from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. +Note: GBTClassifier only supports binary classification currently +Run with: + bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py +""" + + +def testClassification(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = BinaryClassificationMetrics(predictionAndLabels) + print("AUC %.3f" % metrics.areaUnderROC) + + +def testRegression(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGBTExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 000000000000..55afe1b207fe --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py new file mode 100644 index 000000000000..c7730e1bfacd --- /dev/null +++ b/examples/src/main/python/ml/random_forest_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import RandomForestRegressor +from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a RandomForest Classification/Regression Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/random_forest_example.py +""" + + +def testClassification(train, test): + # Train a RandomForest model. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + # Note: Use larger numTrees in practice. + + rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + +def testRegression(train, test): + # Train a RandomForest model. + # Note: Use larger numTrees in practice. + + rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: random_forest_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonRandomForestExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py new file mode 100644 index 000000000000..a9f29dab2d60 --- /dev/null +++ b/examples/src/main/python/ml/simple_params_example.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import pprint +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql import SQLContext + +""" +A simple example demonstrating ways to specify parameters for Estimators and Transformers. +Run with: + bin/spark-submit examples/src/main/python/ml/simple_params_example.py +""" + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: simple_params_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonSimpleParamsExample") + sqlContext = SQLContext(sc) + + # prepare training data. + # We create an RDD of LabeledPoints and convert them into a DataFrame. + # A LabeledPoint is an Object with two fields named label and features + # and Spark SQL identifies these fields and creates the schema appropriately. + training = sc.parallelize([ + LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), + LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), + LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), + LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + + # Create a LogisticRegression instance with maxIter = 10. + # This instance is an Estimator. + lr = LogisticRegression(maxIter=10) + # Print out the parameters, documentation, and any default values. + print("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + # We may also set parameters using setter methods. + lr.setRegParam(0.01) + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a Transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print("Model 1 was fit using parameters:\n") + pprint.pprint(model1.extractParamMap()) + + # We may alternatively specify parameters using a parameter map. + # paramMap overrides all lr parameters set earlier. + paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + + # Now learn a new model using the new parameters. + model2 = lr.fit(training, paramMap) + print("Model 2 was fit using parameters:\n") + pprint.pprint(model2.extractParamMap()) + + # prepare test data. + test = sc.parallelize([ + LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), + LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), + LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegressionModel.transform will only use the 'features' column. + # Note that model2.transform() outputs a 'myProbability' column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + result = model2.transform(test) \ + .select("features", "label", "myProbability", "prediction") \ + .collect() + + for row in result: + print("features=%s,label=%s -> prob=%s, prediction=%s" + % (row.features, row.label, row.myProbability, row.prediction)) + + sc.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 96ddac761d69..e1fd85b082c0 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -51,7 +51,7 @@ parquet_rdd = sc.newAPIHadoopFile( path, - 'parquet.avro.AvroParquetInputFormat', + 'org.apache.parquet.avro.AvroParquetInputFormat', 'java.lang.Void', 'org.apache.avro.generic.IndexedRecord', valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 000000000000..091b64d8c4af --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ + spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py new file mode 100644 index 000000000000..dcd6a0fc6ff9 --- /dev/null +++ b/examples/src/main/python/streaming/queue_stream.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Create a queue of RDDs that will be mapped/reduced one at a time in + 1 second intervals. + + To run this example use + `$ bin/spark-submit examples/src/main/python/streaming/queue_stream.py +""" +import sys +import time + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonStreamingQueueStream") + ssc = StreamingContext(sc, 1) + + # Create the queue through which RDDs can be pushed to + # a QueueInputDStream + rddQueue = [] + for i in xrange(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + + # Create the QueueInputDStream and use it do some processing + inputStream = ssc.queueStream(rddQueue) + mappedStream = inputStream.map(lambda x: (x % 10, 1)) + reducedStream = mappedStream.reduceByKey(lambda a, b: a + b) + reducedStream.pprint() + + ssc.start() + time.sleep(6) + ssc.stop(stopSparkContext=True, stopGraceFully=True) diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R new file mode 100644 index 000000000000..aa2336e300a9 --- /dev/null +++ b/examples/src/main/r/data-manipulation.R @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# For this example, we shall use the "flights" dataset +# The dataset consists of every flight departing Houston in 2011. +# The data set is made up of 227,496 rows x 14 columns. + +# To run this example use +# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 +# examples/src/main/r/data-manipulation.R + +# Load SparkR library into your R session +library(SparkR) + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 1) { + print("Usage: data-manipulation.R % + summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF + + # Print the computed data frame + head(dailyDelayDF) +} + +# Stop the SparkContext now +sparkR.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 11d5c92c5952..023bb3ee2d10 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -104,8 +104,8 @@ object CassandraCQLTest { val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[CqlPagingInputFormat], - classOf[java.util.Map[String,ByteBuffer]], - classOf[java.util.Map[String,ByteBuffer]]) + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.Map[String, ByteBuffer]]) println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { @@ -118,7 +118,7 @@ object CassandraCQLTest { case (productId, saleCount) => println(productId + ":" + saleCount) } - val casoutputCF = aggregatedRDD.map { + val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) val outKey: java.util.Map[String, ByteBuffer] = outColFamKey diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala new file mode 100644 index 000000000000..1f12034ce0f5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.io.File + +import scala.io.Source._ + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.SparkContext._ + +/** + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. Compares the word count results + */ +object DFSReadWriteTest { + + private var localFilePath: File = new File(".") + private var dfsDirPath: String = "" + + private val NPARAMS = 2 + + private def readFile(filename: String): List[String] = { + val lineIter: Iterator[String] = fromFile(filename).getLines() + val lineList: List[String] = lineIter.toList + lineList + } + + private def printUsage(): Unit = { + val usage: String = "DFS Read-Write Test\n" + + "\n" + + "Usage: localFile dfsDir\n" + + "\n" + + "localFile - (string) local file to use in test\n" + + "dfsDir - (string) DFS directory for read/write tests\n" + + println(usage) + } + + private def parseArgs(args: Array[String]): Unit = { + if (args.length != NPARAMS) { + printUsage() + System.exit(1) + } + + var i = 0 + + localFilePath = new File(args(i)) + if (!localFilePath.exists) { + System.err.println("Given path (" + args(i) + ") does not exist.\n") + printUsage() + System.exit(1) + } + + if (!localFilePath.isFile) { + System.err.println("Given path (" + args(i) + ") is not a file.\n") + printUsage() + System.exit(1) + } + + i += 1 + dfsDirPath = args(i) + } + + def runLocalWordCount(fileContents: List[String]): Int = { + fileContents.flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .groupBy(w => w) + .mapValues(_.size) + .values + .sum + } + + def main(args: Array[String]): Unit = { + parseArgs(args) + + println("Performing local word count") + val fileContents = readFile(localFilePath.toString()) + val localWordCount = runLocalWordCount(fileContents) + + println("Creating SparkConf") + val conf = new SparkConf().setAppName("DFS Read Write Test") + + println("Creating SparkContext") + val sc = new SparkContext(conf) + + println("Writing local file to DFS") + val dfsFilename = dfsDirPath + "/dfs_read_write_test" + val fileRDD = sc.parallelize(fileContents) + fileRDD.saveAsTextFile(dfsFilename) + + println("Reading file from DFS and running Word Count") + val readFileRDD = sc.textFile(dfsFilename) + + val dfsWordCount = readFileRDD + .flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .map(w => (w, 1)) + .countByKey() + .values + .sum + + sc.stop() + + if (localWordCount == dfsWordCount) { + println(s"Success! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) agree.") + } else { + println(s"Failure! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) disagree.") + } + + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 849887d23c9c..95c96111c9b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -59,5 +59,6 @@ object HBaseTest { hBaseRDD.count() sc.stop() + admin.close() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index a55e0dc8d36c..c3fc74a116c0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -39,7 +39,7 @@ object LocalLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 32e02eab8b03..75c82117cbad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. - * + * * Usage: LogQuery [logFile] */ object LogQuery { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 6c0ac8013ce3..30c426155183 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -117,7 +117,7 @@ object SparkALS { var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users - val Rc = sc.broadcast(R) + val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 8c01a6084462..1e6b4fb0c751 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -44,7 +44,7 @@ object SparkLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 8d092b6506d3..bd7894f184c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -51,7 +51,7 @@ object SparkPageRank { showWarning() val sparkConf = new SparkConf().setAppName("PageRank") - val iters = if (args.length > 0) args(1).toInt else 10 + val iters = if (args.length > 1) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index ab6e63deb3c9..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.bagel._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) { - self.outEdges.map(targetId => new PRMessage(targetId, newValue / self.outEdges.size)) - } else { - Array[PRMessage]() - } - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double) - (self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int) - : (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)" - .format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions: Int = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } - - override def hashCode: Int = numPartitions -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 859abedf2a55..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ - -import org.apache.spark.bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println( - "Usage: WikipediaPageRank ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.setAppName("WikipediaPageRank") - sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage])) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRank") - val sc = new SparkContext(sparkConf) - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) { - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache() - } else { - vertices = vertices.cache() - } - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect().mkString) - println(top) - - sc.stop() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index 576a3e371b99..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.xml.{XML, NodeSeq} - -import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import scala.reflect.ClassTag - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRankStandalone " + - " ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRankStandalone") - - val sc = new SparkContext(sparkConf) - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) { - input.map(parseArticle _).partitionBy(partitioner).cache() - } else { - input.map(parseArticle _).cache() - } - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, - sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - sc.stop() - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapperIterable, rankWrapperIterable)) => - val linksWrapper = linksWrapperIterable.iterator - val rankWrapper = rankWrapperIterable.iterator - if (linksWrapper.hasNext) { - val linksWrapperHead = linksWrapper.next - if (rankWrapper.hasNext) { - val rankWrapperHead = rankWrapper.next - linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size)) - } else { - linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends org.apache.spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T: ClassTag](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T: ClassTag](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3ee456edbe01..7b8cc21ed898 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String) // Create a model, and return it. new MyLogisticRegressionModel(uid, weights).setParent(this) } + + override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) } /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala new file mode 100644 index 000000000000..b54466fd48bc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.sql.DataFrame + +/** + * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LinearRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \ + * data/mllib/sample_linear_regression_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LinearRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegressionExample") { + head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LinearRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "regression", params.fracTest) + + val lir = new LinearRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + // Train the model + val startTime = System.nanoTime() + val lirModel = lir.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Print the weights and intercept for linear regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala new file mode 100644 index 000000000000..3cf193f353fb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.DataFrame + +/** + * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LogisticRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \ + * data/mllib/sample_libsvm_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LogisticRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + fitIntercept: Boolean = true, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LogisticRegressionExample") { + head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Boolean]("fitIntercept") + .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LogisticRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "classification", params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol("indexedLabel") + stages += labelIndexer + + val lor = new LogisticRegression() + .setFeaturesCol("features") + .setLabelCol("indexedLabel") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + stages += lor + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + // Print the weights and intercept for logistic regression. + println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index b99d0a124601..6927eb8f275c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -73,7 +73,7 @@ object OneVsRestExample { .action((x, c) => c.copy(fracTest = x)) opt[String]("testInput") .text("input path to test dataset. If given, option fracTest is ignored") - .action((x,c) => c.copy(testInput = Some(x))) + .action((x, c) => c.copy(testInput = Some(x))) opt[Int]("maxIter") .text(s"maximum number of iterations for Logistic Regression." + s" default: ${defaultParams.maxIter}") @@ -88,10 +88,10 @@ object OneVsRestExample { .action((x, c) => c.copy(fitIntercept = x)) opt[Double]("regParam") .text(s"the regularization parameter for Logistic Regression.") - .action((x,c) => c.copy(regParam = Some(x))) + .action((x, c) => c.copy(regParam = Some(x))) opt[Double]("elasticNetParam") .text(s"the ElasticNet mixing parameter for Logistic Regression.") - .action((x,c) => c.copy(elasticNetParam = Some(x))) + .action((x, c) => c.copy(elasticNetParam = Some(x))) checkConfig { params => if (params.fracTest < 0 || params.fracTest >= 1) { failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a991f50e33..a0561e2573fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -87,7 +87,7 @@ object SimpleParamsExample { LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. model2.transform(test.toDF()) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index e943d6c889fa..520893b26d59 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -103,10 +103,10 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - df.saveAsParquetFile(outputDir) + df.write.parquet(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.parquetFile(outputDir) + val newDataset = sqlContext.read.parquet(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b0613632c994..3381941673db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,7 +22,6 @@ import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -354,7 +353,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. + * + * This is just for demo purpose. In general, don't copy this code because it is NOT efficient + * due to the use of structural types, which leads to one reflection call per record. */ + // scalastyle:off structural.type private[mllib] def meanSquaredError( model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { @@ -363,4 +366,5 @@ object DecisionTreeRunner { err * err }.mean() } + // scalastyle:on structural.type } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index df76b45e5081..f8c71ccabc43 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -40,23 +40,23 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) - + val ctx = new SparkContext(conf) + val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) .run(data) - + for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } - + println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index a11890d6f2b1..3ebb112fc069 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -36,22 +36,21 @@ object AvroConversionUtil extends Serializable { return null } schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException(s"Unknown Avro schema type ${other.getName}") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 273bee0a8b30..90d48a64106c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -18,20 +18,34 @@ package org.apache.spark.examples.pythonconverters import scala.collection.JavaConversions._ +import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.KeyValue.Type +import org.apache.hadoop.hbase.CellUtil /** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts an - * HBase Result to a String + * Implementation of [[org.apache.spark.api.python.Converter]] that converts all + * the records in an HBase Result to a String */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { + import collection.JavaConverters._ val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) + val output = result.listCells.asScala.map(cell => + Map( + "row" -> Bytes.toStringBinary(CellUtil.cloneRow(cell)), + "columnFamily" -> Bytes.toStringBinary(CellUtil.cloneFamily(cell)), + "qualifier" -> Bytes.toStringBinary(CellUtil.cloneQualifier(cell)), + "timestamp" -> cell.getTimestamp.toString, + "type" -> Type.codeToType(cell.getTypeByte).toString, + "value" -> Bytes.toStringBinary(CellUtil.cloneValue(cell)) + ) + ) + output.map(JSONObject(_).toString()).mkString("\n") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 6331d1c0060f..b11e32047dc3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -58,10 +58,10 @@ object RDDRelation { df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - df.saveAsParquetFile("pair.parquet") + df.write.parquet("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. - val parquetFile = sqlContext.parquetFile("pair.parquet") + val parquetFile = sqlContext.read.parquet("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 92867b44be13..016de4c63d1d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -104,10 +104,8 @@ extends Actor with ActorHelper { object FeederActor { def main(args: Array[String]) { - if(args.length < 2){ - System.err.println( - "Usage: FeederActor \n" - ) + if (args.length < 2){ + System.err.println("Usage: FeederActor \n") System.exit(1) } val Seq(host, port) = args.toSeq diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala similarity index 97% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 11a8cf09533c..fbe394de4a17 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -51,7 +51,7 @@ object DirectKafkaWordCount { // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 95% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index f407367a54f6..60416ee34354 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -49,10 +49,10 @@ object KafkaWordCount { val Array(zkQuorum, group, topics, numThreads) = args val sparkConf = new SparkConf().setAppName("KafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") - val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap + val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1L)) @@ -96,7 +96,7 @@ object KafkaWordCountProducer { producer.send(message) } - Thread.sleep(100) + Thread.sleep(1000) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 85b9a54b40ba..813c8554f519 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -40,7 +40,7 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq - + var client: MqttClient = null try { @@ -49,7 +49,7 @@ object MQTTPublisher { client.connect() - val msgtopic = client.getTopic(topic) + val msgtopic = client.getTopic(topic) val msgContent = "hello mqtt demo for spark streaming" val message = new MqttMessage(msgContent.getBytes("utf-8")) @@ -59,10 +59,10 @@ object MQTTPublisher { println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) + Thread.sleep(10) println("Queue is full, wait for to consume data from the message queue") - } - } + } + } } catch { case e: MqttException => println("Exception Caught: " + e) } finally { @@ -107,7 +107,7 @@ object MQTTWordCount { val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 54d996b8ac99..889f052c7026 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -57,8 +57,7 @@ object PageViewGenerator { 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - + val userID = Map((1 to 100).map(_ -> .01) : _*) def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { val rand = new Random().nextDouble() diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 000000000000..8565cd83edfa --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,135 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + org.apache.avro + avro + ${avro.version} + + + org.apache.avro + avro-ipc + ${avro.version} + + + io.netty + netty + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + org.mortbay.jetty + servlet-api + + + org.apache.velocity + velocity + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a2..0664cfb2021e 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -35,22 +35,49 @@ http://spark.apache.org/ - - org.apache.commons - commons-lang3 - org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index fd01807fc3ac..719fca0938b3 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,14 +16,13 @@ */ package org.apache.spark.streaming.flume.sink +import java.util.UUID import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.flume.Channel -import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process @@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. @@ -55,7 +53,7 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha // Since the new txn may not have the same sequence number we must guard against accidentally // committing a new transaction. To reduce the probability of that happening a random string is // prepended to the sequence number. Does not change for life of sink - private val seqBase = RandomStringUtils.randomAlphanumeric(8) + private val seqBase = UUID.randomUUID().toString.substring(0, 8) private val seqCounter = new AtomicLong(0) // Protected by `sequenceNumberToProcessor` diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala similarity index 51% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala rename to external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala index a4a3a66b8b22..845fc8debda7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -14,23 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.streaming.flume.sink -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong /** - * Overrides our expression evaluation tests and reruns them after optimization has occured. This - * is to ensure that constant folding and other optimizations do not break anything. + * Thread factory that generates daemon threads with a specified name format. */ -class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) - super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t } + } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index ea45b14294df..7ad43b1d7b0a 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -143,7 +143,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg(msg) } else { // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("",seqNum, events) + eventBatch = new EventBatch("", seqNum, events) } }) } catch { diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c14..fa43629d4977 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -24,16 +24,24 @@ import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +// Due to MNG-1378, there is not a way to include test dependencies transitively. +// We cannot include Spark core tests as a dependency here because it depends on +// Spark core main, which has too many dependencies to require here manually. +// For this reason, we continue to use FunSuite and ignore the scalastyle checks +// that fail if this is detected. +//scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { +//scalastyle:on + val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 8df7edbdcad3..14f7daaf417e 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming-flume-sink_${scala.binary.version} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index dc629df4f4ac..65c49c131518 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k,v) <- headers) { + for ((k, v) <- headers) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 60e2994431b3..1e32a365a1ee 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -152,9 +152,9 @@ class FlumeReceiver( val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()) val channelPipelineFactory = new CompressionChannelPipelineFactory() - + new NettyServer( - responder, + responder, new InetSocketAddress(host, port), channelFactory, channelPipelineFactory, @@ -188,12 +188,12 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel + * a successful response to indicate it can remove the event/batch + * from the configured channel */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 92fa5b41be89..583e7dca317a 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver( } /** - * A wrapper around the transceiver and the Avro IPC API. + * A wrapper around the transceiver and the Avro IPC API. * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 000000000000..9d9c3b189415 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.toList) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 44dec45c227c..095bfb0c73a9 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,10 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} +import scala.collection.JavaConversions._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -236,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.length == ports.length) + val addresses = hosts.zip(ports).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 000000000000..91d63d49dbec --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): JList[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + Map[String, String](s"test-$t" -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 93afe50c2134..d5f9a0aa38f9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,47 +18,33 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder +import com.google.common.base.Charsets.UTF_8 +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} - -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} -class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -89,146 +75,55 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging } private def testFlumePolling(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - writeAndVerify(Seq(sink), Seq(channel)) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + try { + val port = utils.startSingleSink() + + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() try { - writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() + utils.close() } } - def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 - } - } - assert(counter === totalEventsPerChannel * channels.size) - } - ssc.stop() - } - - def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach + try { + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.map { + case kv => (kv._1.toString, kv._2.toString) + }).map(mapAsJavaMap) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 39e6754c81db..5bc4cdf65306 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,46 +17,26 @@ package org.apache.spark.streaming.flume -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.util.Utils -class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,19 +49,29 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + output should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -98,58 +88,6 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) - output should be (input) - } - } - /** Class to create socket channel with compression */ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 0b79f47647f6..8059c443827e 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 243ce6eaca65..ded863bd985e 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.kafka kafka_${scala.binary.version} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 6715aede7928..876456c96477 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -65,6 +65,9 @@ class DirectKafkaInputDStream[ val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) + // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]") + private[streaming] override def name: String = s"Kafka direct stream [$id]" + protected[streaming] override val checkpointData = new DirectKafkaInputDStreamCheckpointData @@ -117,8 +120,7 @@ class DirectKafkaInputDStream[ context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) // Report the record number of this batch interval to InputInfoTracker. - val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum - val inputInfo = InputInfo(id, numRecords) + val inputInfo = InputInfo(id, rdd.count) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) @@ -150,10 +152,7 @@ class DirectKafkaInputDStream[ override def restore() { // this is assuming that the topics don't change during execution, which is true currently val topics = fromOffsets.keySet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 6cf254a7b69c..3e6b937af57b 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { r.flatMap { tm: TopicMetadata => tm.partitionsMetadata.map { pm: PartitionMetadata => TopicAndPartition(tm.topic, pm.partitionId) - } + } } } } @@ -360,6 +360,14 @@ private[spark] object KafkaCluster { type Err = ArrayBuffer[Throwable] + /** If the result is right, return it, otherwise throw SparkException */ + def checkErrors[T](result: Either[Err, T]): T = { + result.fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index cca0fac0234e..04b2dc10d39e 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -135,7 +135,7 @@ class KafkaReceiver[ store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { - case e: Throwable => logError("Error handling message; exiting", e) + case e: Throwable => reportError("Error handling message; exiting", e) } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a1b4a12e5d6a..c5cd2154772a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.kafka +import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{PartialResult, BoundedDouble} import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator @@ -60,6 +62,48 @@ class KafkaRDD[ }.toArray } + override def count(): Long = offsetRanges.map(_.count).sum + + override def countApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[R] = { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (num < 1 || nonEmptyPartitions.size < 1) { + return new Array[R](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[R] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, + parts.keys.toArray, + allowLocal = true) + res.foreach(buf ++= _) + buf.toArray + } + override def getPreferredLocations(thePart: Partition): Seq[String] = { val part = thePart.asInstanceOf[KafkaRDDPartition] // TODO is additional hostname resolution necessary here diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a842a6f17766..a660d2a00c35 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -35,4 +35,7 @@ class KafkaRDDPartition( val untilOffset: Long, val host: String, val port: Int -) extends Partition +) extends Partition { + /** Number of messages this partition refers to */ + def count(): Long = untilOffset - fromOffset +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 6dc4e9517d5a..b608b7595272 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging { val props = new Properties() props.put("metadata.broker.list", brokerAddress) props.put("serializer.class", classOf[StringEncoder].getName) + // wait for all in-sync replicas to ack sends + props.put("request.required.acks", "-1") props } @@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging { tryAgain(1) } - /** Wait until the leader offset for the given topic/partition equals the specified offset */ - def waitUntilLeaderOffset( - topic: String, - partition: Int, - offset: Long): Unit = { - eventually(Time(10000), Time(100)) { - val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress)) - val tp = TopicAndPartition(topic, partition) - val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset - assert( - llo == offset, - s"$topic $partition $offset not reached after timeout") - } - } - private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { case Some(partitionState) => diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index d7cf500577c2..0e33362d34ac 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -158,15 +158,31 @@ object KafkaUtils { /** get leaders for the given offset ranges, or throw an exception */ private def leadersForRanges( - kafkaParams: Map[String, String], + kc: KafkaCluster, offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { - val kc = new KafkaCluster(kafkaParams) val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) - leaders + val leaders = kc.findLeaders(topics) + KafkaCluster.checkErrors(leaders) + } + + /** Make sure offsets are available in kafka, or throw an exception */ + private def checkOffsets( + kc: KafkaCluster, + offsetRanges: Array[OffsetRange]): Unit = { + val topics = offsetRanges.map(_.topicAndPartition).toSet + val result = for { + low <- kc.getEarliestLeaderOffsets(topics).right + high <- kc.getLatestLeaderOffsets(topics).right + } yield { + offsetRanges.filterNot { o => + low(o.topicAndPartition).offset <= o.fromOffset && + o.untilOffset <= high(o.topicAndPartition).offset + } + } + val badRanges = KafkaCluster.checkErrors(result) + if (!badRanges.isEmpty) { + throw new SparkException("Offsets not available on leader: " + badRanges.mkString(",")) + } } /** @@ -189,9 +205,11 @@ object KafkaUtils { sc: SparkContext, kafkaParams: Map[String, String], offsetRanges: Array[OffsetRange] - ): RDD[(K, V)] = { + ): RDD[(K, V)] = sc.withScope { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) - val leaders = leadersForRanges(kafkaParams, offsetRanges) + val kc = new KafkaCluster(kafkaParams) + val leaders = leadersForRanges(kc, offsetRanges) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) } @@ -224,16 +242,19 @@ object KafkaUtils { offsetRanges: Array[OffsetRange], leaders: Map[TopicAndPartition, Broker], messageHandler: MessageAndMetadata[K, V] => R - ): RDD[R] = { + ): RDD[R] = sc.withScope { + val kc = new KafkaCluster(kafkaParams) val leaderMap = if (leaders.isEmpty) { - leadersForRanges(kafkaParams, offsetRanges) + leadersForRanges(kc, offsetRanges) } else { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) }.toMap } - new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) + val cleanedHandler = sc.clean(messageHandler) + checkOffsets(kc, offsetRanges) + new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler) } /** @@ -256,7 +277,7 @@ object KafkaUtils { valueDecoderClass: Class[VD], kafkaParams: JMap[String, String], offsetRanges: Array[OffsetRange] - ): JavaPairRDD[K, V] = { + ): JavaPairRDD[K, V] = jsc.sc.withScope { implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) @@ -294,7 +315,7 @@ object KafkaUtils { offsetRanges: Array[OffsetRange], leaders: JMap[TopicAndPartition, Broker], messageHandler: JFunction[MessageAndMetadata[K, V], R] - ): JavaRDD[R] = { + ): JavaRDD[R] = jsc.sc.withScope { implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) @@ -314,7 +335,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -348,8 +369,9 @@ object KafkaUtils { fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R ): InputDStream[R] = { + val cleanedHandler = ssc.sc.clean(messageHandler) new DirectKafkaInputDStream[K, V, KD, VD, R]( - ssc, kafkaParams, fromOffsets, messageHandler) + ssc, kafkaParams, fromOffsets, cleanedHandler) } /** @@ -361,7 +383,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -397,7 +419,7 @@ object KafkaUtils { val kc = new KafkaCluster(kafkaParams) val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - (for { + val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { kc.getEarliestLeaderOffsets(topicPartitions) @@ -410,10 +432,8 @@ object KafkaUtils { } new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( ssc, kafkaParams, fromOffsets, messageHandler) - }).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + } + KafkaCluster.checkErrors(result) } /** @@ -425,7 +445,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -469,11 +489,12 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) createDirectStream[K, V, KD, VD, R]( jssc.ssc, Map(kafkaParams.toSeq: _*), Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), - messageHandler.call _ + cleanedHandler ) } @@ -486,7 +507,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 9c3dfeb8f592..267504266630 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -55,6 +55,12 @@ final class OffsetRange private( val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple + /** Kafka TopicAndPartition object, for convenience */ + def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition) + + /** Number of messages this OffsetRange refers to */ + def count(): Long = untilOffset - fromOffset + override def equals(obj: Any): Boolean = obj match { case that: OffsetRange => this.topic == that.topic && diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index ea87e960379f..75f0dfc22b9d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -267,7 +267,7 @@ class ReliableKafkaReceiver[ } } catch { case e: Exception => - logError("Error handling message", e) + reportError("Error handling message", e) } } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 4c1d6a03eb2b..02cd24a35906 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -18,9 +18,8 @@ package org.apache.spark.streaming.kafka; import java.io.Serializable; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Arrays; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; import scala.Tuple2; @@ -34,6 +33,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; @@ -67,8 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { - String topic1 = "topic1"; - String topic2 = "topic2"; + final String topic1 = "topic1"; + final String topic2 = "topic2"; + // hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); @@ -89,6 +91,17 @@ public void testKafkaStream() throws InterruptedException { StringDecoder.class, kafkaParams, topicToSet(topic1) + ).transformToPair( + // Make sure you can get offset ranges from the rdd + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + Assert.assertEquals(offsets[0].topic(), topic1); + return rdd; + } + } ).map( new Function, String>() { @Override @@ -116,12 +129,17 @@ public String call(MessageAndMetadata msgAndMd) throws Exception ); JavaDStream unifiedStream = stream1.union(stream2); - final HashSet result = new HashSet(); + final Set result = Collections.synchronizedSet(new HashSet()); unifiedStream.foreachRDD( new Function, Void>() { @Override public Void call(JavaRDD rdd) throws Exception { result.addAll(rdd.collect()); + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } return null; } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index 5cf379635354..a9dc6e50613c 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -72,9 +72,6 @@ public void testKafkaRDD() throws InterruptedException { HashMap kafkaParams = new HashMap(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); - kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length); - kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length); - OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), OffsetRange.create(topic2, 0, 0, 1) diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 540f4ceabab4..e4c659215b76 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -18,9 +18,7 @@ package org.apache.spark.streaming.kafka; import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Random; +import java.util.*; import scala.Tuple2; @@ -94,7 +92,7 @@ public void testKafkaStream() throws InterruptedException { topics, StorageLevel.MEMORY_ONLY_SER()); - final HashMap result = new HashMap(); + final Map result = Collections.synchronizedMap(new HashMap()); JavaDStream words = stream.map( new Function, String>() { diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index b6d314dfc778..8e1715f6dbb9 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -28,10 +28,10 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.Utils class DirectKafkaStreamSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll with Eventually @@ -99,15 +99,24 @@ class DirectKafkaStreamSuite ssc, kafkaParams, topics) } - val allReceived = new ArrayBuffer[(String, String)] + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] - stream.foreachRDD { rdd => - // Get the offset ranges in the RDD - val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + // hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + stream.transform { rdd => + // Get the offset ranges in the RDD + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, // and the number of items in the partition - val off = offsets(i) + val off = offsetRanges(i) val all = iter.toSeq val partSize = all.size val rangeSize = off.untilOffset - off.fromOffset @@ -162,7 +171,7 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) @@ -208,7 +217,7 @@ class DirectKafkaStreamSuite "Start offset not from latest" ) - val collectedData = new mutable.ArrayBuffer[String]() + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] stream.foreachRDD { rdd => collectedData ++= rdd.collect() } ssc.start() val newData = Map("b" -> 10) @@ -324,7 +333,8 @@ class DirectKafkaStreamSuite ssc, kafkaParams, Set(topic)) } - val allReceived = new ArrayBuffer[(String, String)] + val allReceived = + new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] stream.foreachRDD { rdd => allReceived ++= rdd.collect() } ssc.start() @@ -350,8 +360,8 @@ class DirectKafkaStreamSuite } object DirectKafkaStreamSuite { - val collectedData = new mutable.ArrayBuffer[String]() - var total = -1L + val collectedData = new mutable.ArrayBuffer[String]() with mutable.SynchronizedBuffer[String] + @volatile var total = -1L class InputInfoCollector extends StreamingListener { val numRecordsSubmitted = new AtomicLong(0L) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index 7fb841b79cb6..d66830cbacde 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { +import org.apache.spark.SparkFunSuite + +class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) private var kc: KafkaCluster = null diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 39c3fb448ff5..f52a738afd65 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,11 +22,11 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.spark._ -class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ @@ -55,21 +55,39 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { test("basic usage") { val topic = s"topicbasic-${Random.nextInt}" kafkaTestUtils.createTopic(topic) - val messages = Set("the", "quick", "brown", "fox") - kafkaTestUtils.sendMessages(topic, messages.toArray) + val messages = Array("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" -> s"test-consumer-${Random.nextInt}") - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size) - val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) - val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet - assert(received === messages) + assert(received === messages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === messages.size) + assert(rdd.countApprox(0).getFinalValue.mean === messages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head._2 === messages.head) + assert(rdd.take(messages.size + 10).size === messages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0))) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, badRanges) + } } test("iterator boundary conditions") { @@ -86,7 +104,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { // this is the "lots of messages" case kafkaTestUtils.sendMessages(topic, sent) val sentCount = sent.values.sum - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount) // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) @@ -113,7 +130,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val sentOnlyOne = Map("d" -> 1) kafkaTestUtils.sendMessages(topic, sentOnlyOne) - kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1) assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 24699dfc33ad..797b07f80d8e 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -23,14 +23,14 @@ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { +class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ @@ -65,7 +65,7 @@ class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() + val result = new mutable.HashMap[String, Long]() with mutable.SynchronizedMap[String, Long] stream.map(_._2).countByValue().foreachRDD { r => val ret = r.collect() ret.toMap.foreach { kv => @@ -77,10 +77,7 @@ class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { ssc.start() eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { - assert(sent.size === result.size) - sent.keys.foreach { k => - assert(sent(k) === result(k).toInt) - } + assert(sent === result) } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 38548dd73b82..80e2df62de3f 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -26,15 +26,15 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends FunSuite +class ReliableKafkaStreamSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with Eventually { private val sparkConf = new SparkConf() diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 98f95a9a64fa..0e41e5781784 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.eclipse.paho org.eclipse.paho.client.mqttv3 diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 3c0ef94cb0fa..7c2f18cb35bd 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,25 +17,12 @@ package org.apache.spark.streaming.mqtt -import java.io.IOException -import java.util.concurrent.Executors -import java.util.Properties - -import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.MqttTopic import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ @@ -57,6 +44,8 @@ class MQTTInputDStream( storageLevel: StorageLevel ) extends ReceiverInputDStream[String](ssc_) { + private[streaming] override def name: String = s"MQTT stream [$id]" + def getReceiver(): Receiver[String] = { new MQTTReceiver(brokerUrl, topic, storageLevel) } @@ -86,7 +75,7 @@ class MQTTReceiver( // Handles Mqtt message override def messageArrived(topic: String, message: MqttMessage) { - store(new String(message.getPayload(),"utf-8")) + store(new String(message.getPayload(), "utf-8")) } override def deliveryComplete(token: IMqttDeliveryToken) { diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index a19a72c58a70..c4bf5aa7869b 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} @@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 8b6a8959ac4c..178ae8de13b5 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.twitter4j twitter4j-stream diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 9ee57d7581d8..d9acb568879f 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.streaming.twitter -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a50d378b3433..37bfd10d4366 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + ${akka.group} akka-zeromq_${scala.binary.version} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index a7566e733d89..35d2e62c6848 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends FunSuite { +class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4351a8a12fe2..3636a9037d43 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 25847a1b33d9..5289073eb457 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -59,7 +59,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index b0bff27a61c1..06e0ff28afd9 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.regex.Pattern; +import com.amazonaws.regions.RegionUtils; import org.apache.log4j.Logger; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -40,140 +41,146 @@ import com.google.common.collect.Lists; /** - * Java-friendly Kinesis Spark Streaming WordCount example + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details - * on the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: JavaKinesisWordCountASL [app-name] [stream-name] [endpoint-url] [region-name] + * [app-name] is the name of the consumer app, used to track the read data in DynamoDB + * [stream-name] name of the Kinesis stream (ie. mySparkStream) + * [endpoint-url] endpoint of the Kinesis service + * (e.g. https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: JavaKinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID=[your-access-key] * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * # run the example + * $ SPARK_HOME/bin/run-example streaming.JavaKinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data - * onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ public final class JavaKinesisWordCountASL { // needs to be public for access from run-example - private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); - private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); - - /* Make the constructor private to enforce singleton */ - private JavaKinesisWordCountASL() { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + public static void main(String[] args) { + // Check that all required args were passed in. + if (args.length != 3) { + System.err.println( + "Usage: JavaKinesisWordCountASL \n\n" + + " is the name of the app, used to track the read data in DynamoDB\n" + + " is the name of the Kinesis stream\n" + + " is the endpoint of the Kinesis service\n" + + " (e.g. https://kinesis.us-east-1.amazonaws.com)\n" + + "Generate data for the Kinesis stream using the example KinesisWordProducerASL.\n" + + "See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more\n" + + "details.\n" + ); + System.exit(1); } - public static void main(String[] args) { - /* Check that all required args were passed in. */ - if (args.length < 2) { - System.err.println( - "Usage: JavaKinesisWordCountASL \n" + - " is the name of the Kinesis stream\n" + - " is the endpoint of the Kinesis service\n" + - " (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); - System.exit(1); - } - - StreamingExamples.setStreamingLogLevels(); - - /* Populate the appropriate variables from the given args */ - String streamName = args[0]; - String endpointUrl = args[1]; - /* Set the batch interval to a fixed 2000 millis (2 seconds) */ - Duration batchInterval = new Duration(2000); - - /* Create a Kinesis client in order to determine the number of shards for the given stream */ - AmazonKinesisClient kinesisClient = new AmazonKinesisClient( - new DefaultAWSCredentialsProviderChain()); - kinesisClient.setEndpoint(endpointUrl); - - /* Determine the number of shards from the stream */ - int numShards = kinesisClient.describeStream(streamName) - .getStreamDescription().getShards().size(); - - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ - int numStreams = numShards; - - /* Setup the Spark config. */ - SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount"); - - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ - Duration checkpointInterval = batchInterval; + // Set default log4j logging level to WARN to hide Spark logs + StreamingExamples.setStreamingLogLevels(); + + // Populate the appropriate variables from the given args + String kinesisAppName = args[0]; + String streamName = args[1]; + String endpointUrl = args[2]; + + // Create a Kinesis client in order to determine the number of shards for the given stream + AmazonKinesisClient kinesisClient = + new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + int numShards = + kinesisClient.describeStream(streamName).getStreamDescription().getShards().size(); + + + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. + int numStreams = numShards; + + // Spark Streaming batch interval + Duration batchInterval = new Duration(2000); + + // Kinesis checkpoint interval. Same as batchInterval for this example. + Duration kinesisCheckpointInterval = batchInterval; + + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName(); + + // Setup the Spark config and StreamingContext + SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + // Create the Kinesis DStreams + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2()) + ); + } - /* Setup the StreamingContext */ - JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + // Union all the streams if there is more than 1 stream + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + // Otherwise, just use the 1 stream + unionStreams = streamsList.get(0); + } - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ - List> streamsList = new ArrayList>(numStreams); - for (int i = 0; i < numStreams; i++) { - streamsList.add( - KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) - ); + // Convert each line of Array[Byte] to String, and split into words + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + // Map each word to a (word, 1) tuple so we can reduce by key to count the words + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } } - - /* Union all the streams if there is more than 1 stream */ - JavaDStream unionStreams; - if (streamsList.size() > 1) { - unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); - } else { - /* Otherwise, just use the 1 stream */ - unionStreams = streamsList.get(0); + ).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } } + ); - /* - * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. - * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. - */ - JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { - @Override - public Iterable call(byte[] line) { - return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); - } - }); - - /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. */ - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); - - /* Print the first 10 wordCounts */ - wordCounts.print(); - - /* Start the streaming context and await termination */ - jssc.start(); - jssc.awaitTermination(); - } + // Print the first 10 wordCounts + wordCounts.print(); + + // Start the streaming context and await termination + jssc.start(); + jssc.awaitTermination(); + } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 32da0858d1a1..be8b62d3cc6b 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -18,223 +18,249 @@ package org.apache.spark.examples.streaming import java.nio.ByteBuffer + import scala.util.Random -import org.apache.spark.Logging -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions -import org.apache.spark.streaming.kinesis.KinesisUtils -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain + +import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials} +import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.PutRecordRequest -import org.apache.log4j.Logger -import org.apache.log4j.Level +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.streaming.dstream.DStream.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils + /** - * Kinesis Spark Streaming WordCount example. + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on - * the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: KinesisWordCountASL + * is the name of the consumer app, used to track the read data in DynamoDB + * name of the Kinesis stream (ie. mySparkStream) + * endpoint of the Kinesis service + * (e.g. https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: KinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * + * # run the example + * $ SPARK_HOME/bin/run-example streaming.KinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com * - * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class below called KinesisWordCountProducerASL which puts - * dummy data onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ -private object KinesisWordCountASL extends Logging { +object KinesisWordCountASL extends Logging { def main(args: Array[String]) { - /* Check that all required args were passed in. */ - if (args.length < 2) { + // Check that all required args were passed in. + if (args.length != 3) { System.err.println( """ - |Usage: KinesisWordCount + |Usage: KinesisWordCountASL + | + | is the name of the consumer app, used to track the read data in DynamoDB | is the name of the Kinesis stream | is the endpoint of the Kinesis service | (e.g. https://kinesis.us-east-1.amazonaws.com) + | + |Generate input data for Kinesis stream using the example KinesisWordProducerASL. + |See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more + |details. """.stripMargin) System.exit(1) } StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ - val Array(streamName, endpointUrl) = args + // Populate the appropriate variables from the given args + val Array(appName, streamName, endpointUrl) = args - /* Determine the number of shards from the stream */ - val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + + // Determine the number of shards from the stream using the low-level Kinesis Client + // from the AWS Java SDK. + val credentials = new DefaultAWSCredentialsProviderChain().getCredentials() + require(credentials != null, + "No AWS credentials found. Please specify credentials using one of the methods specified " + + "in http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html") + val kinesisClient = new AmazonKinesisClient(credentials) kinesisClient.setEndpoint(endpointUrl) - val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() - .size() + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards().size + - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. val numStreams = numShards - /* Setup the and SparkConfig and StreamingContext */ - /* Spark Streaming batch interval */ + // Spark Streaming batch interval val batchInterval = Milliseconds(2000) - val sparkConfig = new SparkConf().setAppName("KinesisWordCount") - val ssc = new StreamingContext(sparkConfig, batchInterval) - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + // Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information + // on sequence number of records that have been received. Same as batchInterval for this + // example. val kinesisCheckpointInterval = batchInterval - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + + // Setup the SparkConfig and StreamingContext + val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + // Create the Kinesis DStreams val kinesisStreams = (0 until numStreams).map { i => - KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + KinesisUtils.createStream(ssc, appName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2) } - /* Union all the streams */ + // Union all the streams val unionStreams = ssc.union(kinesisStreams) - /* Convert each line of Array[Byte] to String, split into words, and count them */ - val words = unionStreams.flatMap(byteArray => new String(byteArray) - .split(" ")) + // Convert each line of Array[Byte] to String, and split into words + val words = unionStreams.flatMap(byteArray => new String(byteArray).split(" ")) - /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + // Map each word to a (word, 1) tuple so we can reduce by key to count the words val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) - /* Print the first 10 wordCounts */ + // Print the first 10 wordCounts wordCounts.print() - /* Start the streaming context and await termination */ + // Start the streaming context and await termination ssc.start() ssc.awaitTermination() } } /** - * Usage: KinesisWordCountProducerASL - * + * Usage: KinesisWordProducerASL \ + * + * * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service + * is the endpoint of the Kinesis service * (ie. https://kinesis.us-east-1.amazonaws.com) * is the rate of records per second to put onto the stream * is the rate of records per second to put onto the stream * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com 10 5 + * $ SPARK_HOME/bin/run-example streaming.KinesisWordProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com us-east-1 10 5 */ -private object KinesisWordCountProducerASL { +object KinesisWordProducerASL { def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: KinesisWordCountProducerASL " + - " ") + if (args.length != 4) { + System.err.println( + """ + |Usage: KinesisWordProducerASL + + | + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. https://kinesis.us-east-1.amazonaws.com) + | is the rate of records per second to put onto the stream + | is the rate of records per second to put onto the stream + | + """.stripMargin) + System.exit(1) } + // Set default log4j logging level to WARN to hide Spark logs StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ + // Populate the appropriate variables from the given args val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args - /* Generate the records and return the totals */ - val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + // Generate the records and return the totals + val totals = generate(stream, endpoint, recordsPerSecond.toInt, + wordsPerRecord.toInt) - /* Print the array of (index, total) tuples */ - println("Totals") - totals.foreach(total => println(total.toString())) + // Print the array of (word, total) tuples + println("Totals for the words sent") + totals.foreach(println(_)) } def generate(stream: String, endpoint: String, recordsPerSecond: Int, - wordsPerRecord: Int): Seq[(Int, Int)] = { + wordsPerRecord: Int): Seq[(String, Int)] = { - val MaxRandomInts = 10 + val randomWords = List("spark", "you", "are", "my", "father") + val totals = scala.collection.mutable.Map[String, Int]() - /* Create the Kinesis client */ + // Create the low-level Kinesis Client from the AWS Java SDK. val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) kinesisClient.setEndpoint(endpoint) println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + - s" $recordsPerSecond records per second and $wordsPerRecord words per record"); - - val totals = new Array[Int](MaxRandomInts) - /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ - for (i <- 1 to 5) { - - /* Generate recordsPerSec records to put onto the stream */ - val records = (1 to recordsPerSecond.toInt).map { recordNum => - /* - * Randomly generate each wordsPerRec words between 0 (inclusive) - * and MAX_RANDOM_INTS (exclusive) - */ + s" $recordsPerSecond records per second and $wordsPerRecord words per record") + + // Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord + for (i <- 1 to 10) { + // Generate recordsPerSec records to put onto the stream + val records = (1 to recordsPerSecond.toInt).foreach { recordNum => + // Randomly generate wordsPerRecord number of words val data = (1 to wordsPerRecord.toInt).map(x => { - /* Generate the random int */ - val randomInt = Random.nextInt(MaxRandomInts) + // Get a random index to a word + val randomWordIdx = Random.nextInt(randomWords.size) + val randomWord = randomWords(randomWordIdx) - /* Keep track of the totals */ - totals(randomInt) += 1 + // Increment total count to compare to server counts later + totals(randomWord) = totals.getOrElse(randomWord, 0) + 1 - randomInt.toString() + randomWord }).mkString(" ") - /* Create a partitionKey based on recordNum */ + // Create a partitionKey based on recordNum val partitionKey = s"partitionKey-$recordNum" - /* Create a PutRecordRequest with an Array[Byte] version of the data */ + // Create a PutRecordRequest with an Array[Byte] version of the data val putRecordRequest = new PutRecordRequest().withStreamName(stream) .withPartitionKey(partitionKey) - .withData(ByteBuffer.wrap(data.getBytes())); + .withData(ByteBuffer.wrap(data.getBytes())) - /* Put the record onto the stream and capture the PutRecordResult */ - val putRecordResult = kinesisClient.putRecord(putRecordRequest); + // Put the record onto the stream and capture the PutRecordResult + val putRecordResult = kinesisClient.putRecord(putRecordRequest) } - /* Sleep for a second */ + // Sleep for a second Thread.sleep(1000) println("Sent " + recordsPerSecond + " records") } - - /* Convert the totals to (index, total) tuple */ - (0 to (MaxRandomInts - 1)).zip(totals) + // Convert the totals to (index, total) tuple + totals.toSeq.sortBy(_._1) } } -/** - * Utility functions for Spark Streaming examples. +/** + * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ private[streaming] object StreamingExamples extends Logging { - - /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + // Set reasonable logging levels for streaming if the user has not configured log4j. def setStreamingLogLevels() { val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements if (!log4jInitialized) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 588e86a1887e..83a453755951 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. * - * @param checkpointInterval + * @param checkpointInterval * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) */ private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, + checkpointInterval: Duration, currentClock: Clock = new SystemClock()) extends Logging { - + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** - * Check if it's time to checkpoint based on the current time and the derived time + * Check if it's time to checkpoint based on the current time and the derived time * for the next checkpoint * * @return true if it's time to checkpoint @@ -48,7 +48,7 @@ private[kinesis] class KinesisCheckpointState( /** * Advance the checkpoint clock by the checkpoint interval. */ - def advanceCheckpoint() = { + def advanceCheckpoint(): Unit = { checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index a7fe4476cacb..1a8a4cecc114 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -16,39 +16,45 @@ */ package org.apache.spark.streaming.kinesis -import java.net.InetAddress import java.util.UUID +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} + import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import com.amazonaws.auth.AWSCredentialsProvider -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +private[kinesis] +case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) + extends AWSCredentials { + override def getAWSAccessKeyId: String = accessKeyId + override def getAWSSecretKey: String = secretKey +} /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: * https://github.com/awslabs/amazon-kinesis-client - * This is a custom receiver used with StreamingContext.receiverStream(Receiver) - * as described here: - * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers - * to run within a Spark Executor. + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers to run within a + * Spark Executor. * * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing + * the KCL will throw errors. This usually requires deleting the backing * DynamoDB table with the same name this Kinesis application. * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Region name used by the Kinesis Client Library for + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -59,92 +65,121 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). * @param storageLevel Storage level to use for storing the received objects - * - * @return ReceiverInputDStream[Array[Byte]] + * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies + * the credentials */ private[kinesis] class KinesisReceiver( appName: String, streamName: String, endpointUrl: String, - checkpointInterval: Duration, + regionName: String, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel) - extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => - - /* - * The following vars are built in the onStart() method which executes in the Spark Worker after - * this code is serialized and shipped remotely. - */ + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => /* - * workerId should be based on the ip address of the actual Spark Worker where this code runs - * (not the Driver's ip address.) + * ================================================================================= + * The following vars are initialize in the onStart() method which executes in the + * Spark worker after this Receiver is serialized and shipped to the worker. + * ================================================================================= */ - var workerId: String = null - /* - * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file at the default location (~/.aws/credentials) shared by all - * AWS SDKs and the AWS CLI - * Instance profile credentials delivered through the Amazon EC2 metadata service + /** + * workerId is used by the KCL should be based on the ip address of the actual Spark Worker + * where this code runs (not the driver's IP address.) */ - var credentialsProvider: AWSCredentialsProvider = null - - /* KCL config instance. */ - var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + private var workerId: String = null - /* - * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the - * IRecordProcessor.processRecords() method. - * We're using our custom KinesisRecordProcessor in this case. + /** + * Worker is the core client abstraction from the Kinesis Client Library (KCL). + * A worker can process more than one shards from the given stream. + * Each shard is assigned its own IRecordProcessor and the worker run multiple such + * processors. */ - var recordProcessorFactory: IRecordProcessorFactory = null + private var worker: Worker = null - /* - * Create a Kinesis Worker. - * This is the core client abstraction from the Kinesis Client Library (KCL). - * We pass the RecordProcessorFactory from above as well as the KCL config instance. - * A Kinesis Worker can process 1..* shards from the given stream - each with its - * own RecordProcessor. - */ - var worker: Worker = null + /** Thread running the worker */ + private var workerThread: Thread = null /** - * This is called when the KinesisReceiver starts and must be non-blocking. - * The KCL creates and manages the receiving/processing thread pool through the Worker.run() - * method. + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { workerId = Utils.localHostName() + ":" + UUID.randomUUID() - credentialsProvider = new DefaultAWSCredentialsProviderChain() - kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, - credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) - recordProcessorFactory = new IRecordProcessorFactory { + + // KCL config instance + val awsCredProvider = resolveAWSCredentialsProvider() + val kinesisClientLibConfiguration = + new KinesisClientLibConfiguration(appName, streamName, awsCredProvider, workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + val recordProcessorFactory = new IRecordProcessorFactory { override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, workerId, new KinesisCheckpointState(checkpointInterval)) } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) - worker.run() + workerThread = new Thread() { + override def run(): Unit = { + try { + worker.run() + } catch { + case NonFatal(e) => + restart("Error running the KCL worker in Receiver", e) + } + } + } + workerThread.setName(s"Kinesis Receiver ${streamId}") + workerThread.setDaemon(true) + workerThread.start() logInfo(s"Started receiver with workerId $workerId") } /** - * This is called when the KinesisReceiver stops. - * The KCL worker.shutdown() method stops the receiving/processing threads. - * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ override def onStop() { - worker.shutdown() - logInfo(s"Shut down receiver with workerId $workerId") + if (workerThread != null) { + if (worker != null) { + worker.shutdown() + worker = null + } + workerThread.join() + workerThread = null + logInfo(s"Stopped receiver for workerId $workerId") + } workerId = null - credentialsProvider = null - kinesisClientLibConfiguration = null - recordProcessorFactory = null - worker = null + } + + /** + * If AWS credential is provided, return a AWSCredentialProvider returning that credential. + * Otherwise, return the DefaultAWSCredentialsProviderChain. + */ + private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { + awsCredentialsOption match { + case Some(awsCredentials) => + logInfo("Using provided AWS credentials") + new AWSCredentialsProvider { + override def getCredentials: AWSCredentials = awsCredentials + override def refresh(): Unit = { } + } + case None => + logInfo("Using DefaultAWSCredentialsProviderChain") + new DefaultAWSCredentialsProviderChain() + } } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index af8cd875b454..fe9e3a0c793e 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -35,7 +35,10 @@ import com.amazonaws.services.kinesis.model.Record /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes @@ -47,8 +50,8 @@ private[kinesis] class KinesisRecordProcessor( workerId: String, checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { - /* shardId to be populated during initialize() */ - var shardId: String = _ + // shardId to be populated during initialize() + private var shardId: String = _ /** * The Kinesis Client Library calls this method during IRecordProcessor initialization. @@ -56,8 +59,8 @@ private[kinesis] class KinesisRecordProcessor( * @param shardId assigned by the KCL to this particular RecordProcessor. */ override def initialize(shardId: String) { - logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") this.shardId = shardId + logInfo(s"Initialized workerId $workerId with shardId $shardId") } /** @@ -66,29 +69,34 @@ private[kinesis] class KinesisRecordProcessor( * and Spark Streaming's Receiver.store(). * * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored + * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { /* - * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming - * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the - * internally-configured Spark serializer (kryo, etc). - * This is not desirable, so we instead store a raw Array[Byte] and decouple - * ourselves from Spark's internal serialization strategy. - */ + * Notes: + * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). + * 2) This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + * 3) For performance, the BlockGenerator is asynchronously queuing elements within its + * memory before creating blocks. This prevents the small block scenario, but requires + * that you register callbacks to know when a block has been generated and stored + * (WAL is sufficient for storage) before can checkpoint back to the source. + */ batch.foreach(record => receiver.store(record.getData().array())) - + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored + * Checkpoint the sequence number of the last record successfully processed/stored * in the batch. * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). * However, if the worker dies unexpectedly, a checkpoint may not happen. @@ -116,22 +124,22 @@ private[kinesis] class KinesisRecordProcessor( logError(s"Exception: WorkerId $workerId encountered and exception while storing " + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) - /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e } } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } /** * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards + * 1) the stream is resharding by splitting or merging adjacent shards * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason * (ShutdownReason.ZOMBIE) * * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE @@ -145,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor( * Checkpoint to indicate that all records from the shard have been drained and processed. * It's now OK to read from the new shards that resulted from a resharding event. */ - case ShutdownReason.TERMINATE => + case ShutdownReason.TERMINATE => KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) /* @@ -190,7 +198,7 @@ private[kinesis] object KinesisRecordProcessor extends Logging { logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) } - /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + /* Throw: Shutdown has been requested by the Kinesis Client Library. */ case _: ShutdownException => { logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 96f4399accd3..e5acab50181e 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -16,29 +16,78 @@ */ package org.apache.spark.streaming.kinesis -import org.apache.spark.annotation.Experimental +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream -import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.apache.spark.streaming.{Duration, StreamingContext} -/** - * Helper class to create Amazon Kinesis Input Stream - * :: Experimental :: - */ -@Experimental object KinesisUtils { /** - * Create an InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: - * @param ssc StreamingContext object + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, None)) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -48,28 +97,84 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return ReceiverInputDStream[Array[Byte]] + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. + * + * @param ssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental + @deprecated("use other forms of createStream", "1.4.0") def createStream( ssc: StreamingContext, streamName: String, endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, - checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), + initialPositionInStream, checkpointInterval, storageLevel, None)) } /** - * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -79,19 +184,116 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return JavaReceiverInputDStream[Array[Byte]] + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental def createStream( - jssc: JavaStreamingContext, - streamName: String, - endpointUrl: String, + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. + * + * @param jssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + @deprecated("use other forms of createStream", "1.4.0") + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { - jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, - endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream( + jssc.ssc, streamName, endpointUrl, checkpointInterval, initialPositionInStream, storageLevel) + } + + private def getRegionByEndpoint(endpointUrl: String): String = { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() + } + + private def validateRegion(regionName: String): String = { + Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { + throw new IllegalArgumentException(s"Region name '$regionName' is not valid") + } } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 255fe6581960..6c262624833c 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,32 +20,29 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions.seqAsJavaList -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.Seconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.TestSuiteBase -import org.apache.spark.util.{ManualClock, Clock} - -import org.mockito.Mockito._ -import org.scalatest.BeforeAndAfter -import org.scalatest.Matchers -import org.scalatest.mock.MockitoSugar - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.mockito.Mockito._ +// scalastyle:off +// To avoid introducing a dependency on Spark core tests, simply use scalatest's FunSuite +// here instead of our own SparkFunSuite. Introducing the dependency has caused problems +// in the past (SPARK-8781) that are complicated by bugs in the maven shade plugin (MSHADE-148). +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext} +import org.apache.spark.util.{Clock, ManualClock, Utils} /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor */ -class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter - with MockitoSugar { +class KinesisReceiverSuite extends FunSuite with Matchers with BeforeAndAfter + with MockitoSugar { +// scalastyle:on val app = "TestKinesisReceiver" val stream = "mySparkStream" @@ -65,7 +62,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var checkpointStateMock: KinesisCheckpointState = _ var currentClockMock: Clock = _ - override def beforeFunction() = { + before { receiverMock = mock[KinesisReceiver] checkpointerMock = mock[IRecordProcessorCheckpointer] checkpointClockMock = mock[ManualClock] @@ -73,23 +70,35 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft currentClockMock = mock[Clock] } - override def afterFunction(): Unit = { - super.afterFunction() + after { // Since this suite was originally written using EasyMock, add this to preserve the old // mocking semantics (see SPARK-5735 for more details) verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, checkpointStateMock, currentClockMock) } - test("kinesis utils api") { - val ssc = new StreamingContext(master, framework, batchDuration) + test("KinesisUtils API") { + val ssc = new StreamingContext("local[2]", getClass.getSimpleName, Seconds(1)) // Tests the API, does not actually test data receiving - val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + ssc.stop() } + test("check serializability of SerializableAWSCredentials") { + Utils.deserialize[SerializableAWSCredentials]( + Utils.serialize(new SerializableAWSCredentials("x", "y"))) + } + test("process records including store and checkpoint") { when(receiverMock.isStopped()).thenReturn(false) when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index e14bbae4a9b6..478d0019a25f 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index d38a3aa8256b..853dea9a7795 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 058c8c8aa1b2..ce1054ed92ba 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -26,8 +26,8 @@ class EdgeDirection private (private val name: String) extends Serializable { * out becomes in and both and either remain the same. */ def reverse: EdgeDirection = this match { - case EdgeDirection.In => EdgeDirection.Out - case EdgeDirection.Out => EdgeDirection.In + case EdgeDirection.In => EdgeDirection.Out + case EdgeDirection.Out => EdgeDirection.In case EdgeDirection.Either => EdgeDirection.Either case EdgeDirection.Both => EdgeDirection.Both } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index cc70b396a8dd..4611a3ace219 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -41,14 +41,16 @@ abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } + // scalastyle:on structural.type override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { - p.next._2.iterator.map(_.copy()) + p.next()._2.iterator.map(_.copy()) } else { Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index c8790cac3d8a..65f82429d202 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -37,7 +37,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { /** * Set the edge properties of this triplet. */ - protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = { + protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = { srcId = other.srcId dstId = other.dstId attr = other.attr diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 36dc7b0f86c8..db73a8abc573 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -316,7 +316,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * satisfy the predicates */ def subgraph( - epred: EdgeTriplet[VD,ED] => Boolean = (x => true), + epred: EdgeTriplet[VD, ED] => Boolean = (x => true), vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 7edd627b2091..9451ff1e5c0e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -124,18 +124,18 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { val nbrs = edgeDirection match { case EdgeDirection.Either => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => { ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) }, (a, b) => a ++ b, TripletFields.All) case EdgeDirection.In => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), (a, b) => a ++ b, TripletFields.Src) case EdgeDirection.Out => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), (a, b) => a ++ b, TripletFields.Dst) case EdgeDirection.Both => @@ -253,7 +253,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def filter[VD2: ClassTag, ED2: ClassTag]( preprocess: Graph[VD, ED] => Graph[VD2, ED2], epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true, - vpred: (VertexId, VD2) => Boolean = (v:VertexId, d:VD2) => true): Graph[VD, ED] = { + vpred: (VertexId, VD2) => Boolean = (v: VertexId, d: VD2) => true): Graph[VD, ED] = { graph.mask(preprocess(graph).subgraph(epred, vpred)) } @@ -356,7 +356,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)( vprog: (VertexId, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 01b013ff716f..cfcf7244eaed 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -147,10 +147,10 @@ object Pregel extends Logging { logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs - oldMessages.unpersist(blocking=false) - newVerts.unpersist(blocking=false) - prevG.unpersistVertices(blocking=false) - prevG.edges.unpersist(blocking=false) + oldMessages.unpersist(blocking = false) + newVerts.unpersist(blocking = false) + prevG.unpersistVertices(blocking = false) + prevG.edges.unpersist(blocking = false) // count the iteration i += 1 } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index c56157080925..ab021a252eb8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -156,8 +156,8 @@ class EdgePartition[ val size = data.size var i = 0 while (i < size) { - edge.srcId = srcIds(i) - edge.dstId = dstIds(i) + edge.srcId = srcIds(i) + edge.dstId = dstIds(i) edge.attr = data(i) newData(i) = f(edge) i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index bc974b2f04e7..8c0a461e99fa 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -116,7 +116,7 @@ object PageRank extends Logging { val personalized = srcId isDefined val src: VertexId = srcId.getOrElse(-1L) - def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 } + def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } var iteration = 0 var prevRankGraph: Graph[Double, Double] = null @@ -133,13 +133,13 @@ object PageRank extends Logging { // edge partitions. prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId ,id: VertexId) => resetProb * delta(src,id) + (src: VertexId , id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -243,7 +243,7 @@ object PageRank extends Logging { // Execute a dynamic version of Pregel. val vp = if (personalized) { - (id: VertexId, attr: (Double, Double),msgSum: Double) => + (id: VertexId, attr: (Double, Double), msgSum: Double) => personalizedVertexProgram(id, attr, msgSum) } else { (id: VertexId, attr: (Double, Double), msgSum: Double) => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 3b0e1628d86b..9cb24ed080e1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -210,7 +210,7 @@ object SVDPlusPlus { /** * Forces materialization of a Graph by count()ing its RDDs. */ - private def materialize(g: Graph[_,_]): Unit = { + private def materialize(g: Graph[_, _]): Unit = { g.vertices.count() g.edges.count() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index daf162085e3e..a5d598053f9c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ */ object TriangleCount { - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Remove redundant edges val g = graph.groupEdges((a, b) => a).cache() @@ -49,7 +49,7 @@ object TriangleCount { var i = 0 while (i < nbrs.size) { // prevent self cycle - if(nbrs(i) != vid) { + if (nbrs(i) != vid) { set.add(nbrs(i)) } i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 2d6a825b6172..9591c4e9b8f4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -243,14 +243,15 @@ object GraphGenerators { * @return A graph containing vertices with the row and column ids * as their attributes and edge values as 1.0. */ - def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = { + def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = { // Convert row column address into vertex ids (row major order) def sub2ind(r: Int, c: Int): VertexId = r * cols + c - val vertices: RDD[(VertexId, (Int,Int))] = - sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) ) + val vertices: RDD[(VertexId, (Int, Int))] = sc.parallelize(0 until rows).flatMap { r => + (0 until cols).map( c => (sub2ind(r, c), (r, c)) ) + } val edges: RDD[Edge[Double]] = - vertices.flatMap{ case (vid, (r,c)) => + vertices.flatMap{ case (vid, (r, c)) => (if (r + 1 < rows) { Seq( (sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++ (if (c + 1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty }) }.map{ case (src, dst) => Edge(src, dst, 1.0) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala index eb1dbe52c2fd..f1ecc9e2219d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel -class EdgeRDDSuite extends FunSuite with LocalSparkContext { +class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext { test("cache, getStorageLevel") { // test to see if getStorageLevel returns correct value after caching diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 5a2c73b41427..094a63472eaa 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class EdgeSuite extends FunSuite { +class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( - Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), - Edge(0x2345L, 0x1234L, 1), - Edge(0x1234L, 0x5678L, 1), - Edge(0x1234L, 0x2345L, 1), + Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), + Edge(0x2345L, 0x1234L, 1), + Edge(0x1234L, 0x5678L, 1), + Edge(0x1234L, 0x2345L, 1), Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1) ) // to ascending order val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int]) - + for (i <- 0 until testEdges.length) { assert(sortedEdges(i) == testEdges(testEdges.length - i - 1)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 9bc8007ce49c..57a8b95dd12e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd._ -import org.scalatest.FunSuite -class GraphOpsSuite extends FunSuite with LocalSparkContext { +class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { test("joinVertices") { withSpark { sc => @@ -59,7 +58,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test ("filter") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val filteredGraph = graph.filter( @@ -67,11 +66,11 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { val degrees: VertexRDD[Int] = graph.outDegrees graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} }, - vpred = (vid: VertexId, deg:Int) => deg > 0 + vpred = (vid: VertexId, deg: Int) => deg > 0 ).cache() val v = filteredGraph.vertices.collect().toSet - assert(v === Set((0,0))) + assert(v === Set((0, 0))) // the map is necessary because of object-reuse in the edge iterator val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index a570e4ed75fc..1f5e27d5508b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class GraphSuite extends FunSuite with LocalSparkContext { +class GraphSuite extends SparkFunSuite with LocalSparkContext { def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = { Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v") @@ -248,7 +246,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { test("mask") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() @@ -260,11 +258,11 @@ class GraphSuite extends FunSuite with LocalSparkContext { val projectedGraph = graph.mask(subgraph) val v = projectedGraph.vertices.collect().toSet - assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5))) + assert(v === Set((0, 0), (1, 1), (2, 2), (4, 4), (5, 5))) // the map is necessary because of object-reuse in the edge iterator val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet - assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5))) + assert(e === Set(Edge(0, 1, 1), Edge(0, 2, 2), Edge(0, 5, 5))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 490b94429ea1..8afa2d403b53 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd._ -class PregelSuite extends FunSuite with LocalSparkContext { +class PregelSuite extends SparkFunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index d0a7198d691d..f1aa685a79c9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.{HashPartitioner, SparkContext} +import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class VertexRDDSuite extends FunSuite with LocalSparkContext { +class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 515f3a9cd02e..7435647c6d9e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class EdgePartitionSuite extends FunSuite { +class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { val builder = new EdgePartitionBuilder[A, Int] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index fe8304c1cdc3..1203f8959f50 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.graphx.impl -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class VertexPartitionSuite extends FunSuite { +class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 4cc30a96408f..c965a6eb8df1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => @@ -52,13 +50,16 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - if(id < 10) { assert(cc === 0) } - else { assert(cc === 10) } + if (id < 10) { + assert(cc === 0) + } else { + assert(cc === 10) + } } val ccMap = vertices.toMap for (id <- 0 until 20) { @@ -75,7 +76,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() @@ -106,9 +107,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { (4L, ("peter", "student")))) // Create an RDD for edges val relationships: RDD[Edge[String]] = - sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), + sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"), - Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) + Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) // Edges are: // 2 ---> 5 ---> 3 // | \ diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala index 61fd0c460556..808877f0590f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class LabelPropagationSuite extends FunSuite with LocalSparkContext { +class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext { test("Label Propagation") { withSpark { sc => // Construct a graph with two cliques connected by a single edge diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 3f3c9dfd7b3d..45f1e3011035 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators @@ -31,14 +30,14 @@ object GridPageRank { def sub2ind(r: Int, c: Int): Int = r * nCols + c // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { - val ind = sub2ind(r,c) + val ind = sub2ind(r, c) if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r + 1,c)) += ind + inNbrs(sub2ind(r + 1, c)) += ind } if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c + 1)) += ind + inNbrs(sub2ind(r, c + 1)) += ind } } // compute the pagerank @@ -57,7 +56,7 @@ object GridPageRank { } -class PageRankSuite extends FunSuite with LocalSparkContext { +class PageRankSuite extends SparkFunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } @@ -99,8 +98,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val resetProb = 0.15 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb) + val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) .vertices.cache() // Static PageRank should only take 2 iterations to converge @@ -117,7 +116,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } assert(staticErrors.sum === 0) - val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache() + val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ -162,7 +161,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 7bd6b7f3c4ab..2991438f5e57 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index f2c38e79c452..d7eaa70ce640 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ShortestPathsSuite extends FunSuite with LocalSparkContext { +class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { test("Shortest Path Computations") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index 1f658c371ffc..d6b03208180d 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Island Strongly Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 293c7f3ba4c2..c47552cf3a3b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut -class TriangleCountSuite extends FunSuite with LocalSparkContext { +class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => @@ -58,7 +57,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext { val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(0L -> -1L, -1L -> -2L, -2L -> 0L) - val revTriangles = triangles.map { case (a,b) => (b,a) } + val revTriangles = triangles.map { case (a, b) => (b, a) } val rawEdges = sc.parallelize(triangles ++ revTriangles, 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index f3b3738db0da..186d0cc2a977 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BytecodeUtilsSuite extends FunSuite { +class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 8d9c8ddccbb3..32e0c841c699 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.LocalSparkContext -class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { +class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext { test("GraphGenerators.generateRandomEdges") { val src = 5 diff --git a/launcher/pom.xml b/launcher/pom.xml index ebfa7685eaa1..2fd768d8119c 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,14 +22,14 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml org.apache.spark spark-launcher_2.10 jar - Spark Launcher Project + Spark Project Launcher http://spark.apache.org/ launcher @@ -49,7 +49,7 @@ org.mockito - mockito-all + mockito-core test @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 33fd813f7a86..5e793a5c4877 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -136,7 +136,7 @@ void addPermGenSizeOpt(List cmd) { } } - cmd.add("-XX:MaxPermSize=128m"); + cmd.add("-XX:MaxPermSize=256m"); } void addOptionString(List cmd, String options) { @@ -296,6 +296,9 @@ Properties loadPropertiesFile() throws IOException { try { fd = new FileInputStream(propsFile); props.load(new InputStreamReader(fd, "UTF-8")); + for (Map.Entry e : props.entrySet()) { + e.setValue(e.getValue().toString().trim()); + } } finally { if (fd != null) { try { diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 2665a700fe1f..a16c0d2b5ca0 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -27,7 +27,7 @@ */ class CommandBuilderUtils { - static final String DEFAULT_MEM = "512m"; + static final String DEFAULT_MEM = "1g"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 929b29a49ed7..62492f9baf3b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -53,21 +53,33 @@ public static void main(String[] argsArray) throws Exception { List args = new ArrayList(Arrays.asList(argsArray)); String className = args.remove(0); - boolean printLaunchCommand; - boolean printUsage; + boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); AbstractCommandBuilder builder; - try { - if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + if (className.equals("org.apache.spark.deploy.SparkSubmit")) { + try { builder = new SparkSubmitCommandBuilder(args); - } else { - builder = new SparkClassCommandBuilder(className, args); + } catch (IllegalArgumentException e) { + printLaunchCommand = false; + System.err.println("Error: " + e.getMessage()); + System.err.println(); + + MainClassOptionParser parser = new MainClassOptionParser(); + try { + parser.parse(args); + } catch (Exception ignored) { + // Ignore parsing exceptions. + } + + List help = new ArrayList(); + if (parser.className != null) { + help.add(parser.CLASS); + help.add(parser.className); + } + help.add(parser.USAGE_ERROR); + builder = new SparkSubmitCommandBuilder(help); } - printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - printUsage = false; - } catch (IllegalArgumentException e) { - builder = new UsageCommandBuilder(e.getMessage()); - printLaunchCommand = false; - printUsage = true; + } else { + builder = new SparkClassCommandBuilder(className, args); } Map env = new HashMap(); @@ -78,13 +90,7 @@ public static void main(String[] argsArray) throws Exception { } if (isWindows()) { - // When printing the usage message, we can't use "cmd /v" since that prevents the env - // variable from being seen in the caller script. So do not call prepareWindowsCommand(). - if (printUsage) { - System.out.println(join(" ", cmd)); - } else { - System.out.println(prepareWindowsCommand(cmd, env)); - } + System.out.println(prepareWindowsCommand(cmd, env)); } else { // In bash, use NULL as the arg separator since it cannot be used in an argument. List bashCmd = prepareBashCommand(cmd, env); @@ -135,33 +141,30 @@ private static List prepareBashCommand(List cmd, Map buildCommand(Map env) { - if (isWindows()) { - return Arrays.asList("set", "SPARK_LAUNCHER_USAGE_ERROR=" + message); - } else { - return Arrays.asList("usage", message, "1"); - } + protected boolean handleUnknown(String opt) { + return false; + } + + @Override + protected void handleExtraArgs(List extra) { + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index d80abf2a8676..de85720febf2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -93,6 +93,9 @@ public List buildCommand(Map env) throws IOException { toolsDir.getAbsolutePath(), className); javaOptsKeys.add("SPARK_JAVA_OPTS"); + } else { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + memKey = "SPARK_DRIVER_MEMORY"; } List cmd = buildJavaCommand(extraClassPath); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 7d387d406eda..87c43aa9980e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -77,6 +77,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } private final List sparkArgs; + private final boolean printHelp; /** * Controls whether mixing spark-submit arguments with app arguments is allowed. This is needed @@ -87,10 +88,11 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList(); + this.printHelp = false; } SparkSubmitCommandBuilder(List args) { - this(); + this.sparkArgs = new ArrayList(); List submitArgs = args; if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { this.allowsMixedArguments = true; @@ -104,14 +106,16 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { this.allowsMixedArguments = false; } - new OptionParser().parse(submitArgs); + OptionParser parser = new OptionParser(); + parser.parse(submitArgs); + this.printHelp = parser.helpRequested; } @Override public List buildCommand(Map env) throws IOException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource)) { + if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) { + } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printHelp) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -204,7 +208,7 @@ private List buildSparkSubmitCommand(Map env) throws IOE // - properties file. // - SPARK_DRIVER_MEMORY env variable // - SPARK_MEM env variable - // - default value (512m) + // - default value (1g) // Take Thrift Server as daemon String tsMemory = isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null; @@ -311,6 +315,8 @@ private boolean isThriftServer(String mainClass) { private class OptionParser extends SparkSubmitOptionParser { + boolean helpRequested = false; + @Override protected boolean handle(String opt, String value) { if (opt.equals(MASTER)) { @@ -341,6 +347,9 @@ protected boolean handle(String opt, String value) { allowsMixedArguments = true; appResource = specialClasses.get(value); } + } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { + helpRequested = true; + sparkArgs.add(opt); } else { sparkArgs.add(opt); if (value != null) { @@ -360,6 +369,7 @@ protected boolean handleUnknown(String opt) { appArgs.add(opt); return true; } else { + checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); sparkArgs.add(opt); return false; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 229000087688..b88bba883ac6 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -61,6 +61,7 @@ class SparkSubmitOptionParser { // Options that do not take arguments. protected final String HELP = "--help"; protected final String SUPERVISE = "--supervise"; + protected final String USAGE_ERROR = "--usage-error"; protected final String VERBOSE = "--verbose"; protected final String VERSION = "--version"; @@ -120,6 +121,7 @@ class SparkSubmitOptionParser { final String[][] switches = { { HELP, "-h" }, { SUPERVISE }, + { USAGE_ERROR }, { VERBOSE, "-v" }, { VERSION }, }; diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 97043a76cc61..7329ac9f7fb8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -194,7 +194,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception { if (isDriver) { assertEquals("-XX:MaxPermSize=256m", arg); } else { - assertEquals("-XX:MaxPermSize=128m", arg); + assertEquals("-XX:MaxPermSize=256m", arg); } } } diff --git a/make-distribution.sh b/make-distribution.sh index 8d6e91d67593..9f063da3a16c 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -141,22 +141,6 @@ SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hi # because we use "set -o pipefail" echo -n) -JAVA_CMD="$JAVA_HOME"/bin/java -JAVA_VERSION=$("$JAVA_CMD" -version 2>&1) -if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then - echo "***NOTE***: JAVA_HOME is not set to a JDK 6 installation. The resulting" - echo " distribution may not work well with PySpark and will not run" - echo " with Java 6 (See SPARK-1703 and SPARK-1911)." - echo " This test can be disabled by adding --skip-java-test." - echo "Output from 'java -version' was:" - echo "$JAVA_VERSION" - read -p "Would you like to continue anyways? [y,n]: " -r - if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then - echo "Okay, exiting." - exit 1 - fi -fi - if [ "$NAME" == "none" ]; then NAME=$SPARK_HADOOP_VERSION fi @@ -231,6 +215,11 @@ cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" cp -r "$SPARK_HOME/sbin" "$DISTDIR" cp -r "$SPARK_HOME/ec2" "$DISTDIR" +# Copy SparkR if it exists +if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then + mkdir -p "$DISTDIR"/R/lib + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib +fi # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then diff --git a/mllib/pom.xml b/mllib/pom.xml index 0c07ca1a62fd..a5db14407b4f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -99,7 +106,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 7f3f3262a644..57e416591de6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,16 +19,16 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for estimators that fit models to data. */ -@AlphaComponent -abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { +@DeveloperApi +abstract class Estimator[M <: Model[M]] extends PipelineStage { /** * Fits a single model to the input data with optional parameters. @@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { paramMaps.map(fit(dataset, _)) } - override def copy(extra: ParamMap): Estimator[M] = { - super.copy(extra).asInstanceOf[Estimator[M]] - } + override def copy(extra: ParamMap): Estimator[M] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 7fd515369b19..252acc156583 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -17,22 +17,22 @@ package org.apache.spark.ml -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. * * @tparam M model type */ -@AlphaComponent +@DeveloperApi abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. * Note: For ensembles' component Models, this value can be null. */ - var parent: Estimator[M] = _ + @transient var parent: Estimator[M] = _ /** * Sets the parent of this model (Java API). @@ -42,8 +42,8 @@ abstract class Model[M <: Model[M]] extends Transformer { this.asInstanceOf[M] } - override def copy(extra: ParamMap): M = { - // The default implementation of Params.copy doesn't work for models. - throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)") - } + /** Indicates whether this [[Model]] has a corresponding parent. */ + def hasParent: Boolean = parent != null + + override def copy(extra: ParamMap): M } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index fac54188f9f4..a1f3851d804f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -17,20 +17,23 @@ package org.apache.spark.ml +import java.{util => ju} + +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. */ -@AlphaComponent +@DeveloperApi abstract class PipelineStage extends Params with Logging { /** @@ -63,13 +66,11 @@ abstract class PipelineStage extends Params with Logging { outputSchema } - override def copy(extra: ParamMap): PipelineStage = { - super.copy(extra).asInstanceOf[PipelineStage] - } + override def copy(extra: ParamMap): PipelineStage } /** - * :: AlphaComponent :: + * :: Experimental :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will @@ -80,7 +81,7 @@ abstract class PipelineStage extends Params with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ -@AlphaComponent +@Experimental class Pipeline(override val uid: String) extends Estimator[PipelineModel] { def this() = this(Identifiable.randomUID("pipeline")) @@ -97,12 +98,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { /** @group getParam */ def getStages: Array[PipelineStage] = $(stages).clone() - override def validateParams(paramMap: ParamMap): Unit = { - val map = extractParamMap(paramMap) - getStages.foreach { - case pStage: Params => pStage.validateParams(map) - case _ => - } + override def validateParams(): Unit = { + super.validateParams() + $(stages).foreach(_.validateParams()) } /** @@ -169,15 +167,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] { } /** - * :: AlphaComponent :: - * Represents a compiled pipeline. + * :: Experimental :: + * Represents a fitted pipeline. */ -@AlphaComponent +@Experimental class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { + /** A Java/Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, stages: ju.List[Transformer]) = { + this(uid, stages.asScala.toArray) + } + override def validateParams(): Unit = { super.validateParams() stages.foreach(_.validateParams()) @@ -193,6 +196,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages) + new PipelineModel(uid, stages.map(_.copy(extra))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index ec0f76aa668b..333b42711ec5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,6 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * * Abstraction for prediction problems (regression and classification). * * @tparam FeaturesType Type of features. @@ -91,9 +90,7 @@ abstract class Predictor[ copyValues(train(dataset).setParent(this)) } - override def copy(extra: ParamMap): Learner = { - super.copy(extra).asInstanceOf[Learner] - } + override def copy(extra: ParamMap): Learner /** * Train a model using the given dataset and parameters. @@ -113,7 +110,6 @@ abstract class Predictor[ * * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. */ - @DeveloperApi private[ml] def featuresDataType: DataType = new VectorUDT override def transformSchema(schema: StructType): StructType = { @@ -126,15 +122,12 @@ abstract class Predictor[ */ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { dataset.select($(labelCol), $(featuresCol)) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - } + .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } } /** * :: DeveloperApi :: - * * Abstraction for a model for prediction tasks (regression and classification). * * @tparam FeaturesType Type of features. @@ -176,7 +169,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { - dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index d96b54e511e9..3c7bcf7590e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml import scala.annotation.varargs import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame @@ -28,11 +28,11 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for transformers that transform one dataset into another. */ -@AlphaComponent -abstract class Transformer extends PipelineStage with Params { +@DeveloperApi +abstract class Transformer extends PipelineStage { /** * Transforms the dataset with optional parameters @@ -67,16 +67,16 @@ abstract class Transformer extends PipelineStage with Params { */ def transform(dataset: DataFrame): DataFrame - override def copy(extra: ParamMap): Transformer = { - super.copy(extra).asInstanceOf[Transformer] - } + override def copy(extra: ParamMap): Transformer } /** + * :: DeveloperApi :: * Abstract class for transformers that take one input column, apply transformation, and output the * result as a new column. */ -private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] +@DeveloperApi +abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { /** @group setParam */ @@ -118,4 +118,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O dataset.withColumn($(outputCol), callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) } + + override def copy(extra: ParamMap): T = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index f5f37aa77929..457c15830fd3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -19,10 +19,12 @@ package org.apache.spark.ml.attribute import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Attributes that describe a vector ML column. * * @param name name of the attribute group (the ML column name) @@ -31,6 +33,7 @@ import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} * @param attrs optional array of attributes. Attribute will be copied with their corresponding * indices in the array. */ +@DeveloperApi class AttributeGroup private ( val name: String, val numAttributes: Option[Int], @@ -182,7 +185,11 @@ class AttributeGroup private ( } } -/** Factory methods to create attribute groups. */ +/** + * :: DeveloperApi :: + * Factory methods to create attribute groups. + */ +@DeveloperApi object AttributeGroup { import AttributeKeys._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala index a83febd7de2c..5c7089b49167 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -17,12 +17,17 @@ package org.apache.spark.ml.attribute +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], * and [[AttributeType$#Binary]]. */ +@DeveloperApi sealed abstract class AttributeType(val name: String) +@DeveloperApi object AttributeType { /** Numeric type. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index e8f7f152784a..e479f169021d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -19,11 +19,14 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs -import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Abstract class for ML attributes. */ +@DeveloperApi sealed abstract class Attribute extends Serializable { name.foreach { n => @@ -124,7 +127,7 @@ private[attribute] trait AttributeFactory { * Creates an [[Attribute]] from a [[StructField]] instance. */ def fromStructField(field: StructField): Attribute = { - require(field.dataType == DoubleType) + require(field.dataType.isInstanceOf[NumericType]) val metadata = field.metadata val mlAttr = AttributeKeys.ML_ATTR if (metadata.contains(mlAttr)) { @@ -135,6 +138,10 @@ private[attribute] trait AttributeFactory { } } +/** + * :: DeveloperApi :: + */ +@DeveloperApi object Attribute extends AttributeFactory { private[attribute] override def fromMetadata(metadata: Metadata): Attribute = { @@ -163,6 +170,7 @@ object Attribute extends AttributeFactory { /** + * :: DeveloperApi :: * A numeric attribute with optional summary statistics. * @param name optional name * @param index optional index @@ -171,6 +179,7 @@ object Attribute extends AttributeFactory { * @param std optional standard deviation * @param sparsity optional sparsity (ratio of zeros) */ +@DeveloperApi class NumericAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -278,8 +287,10 @@ class NumericAttribute private[ml] ( } /** + * :: DeveloperApi :: * Factory methods for numeric attributes. */ +@DeveloperApi object NumericAttribute extends AttributeFactory { /** The default numeric attribute. */ @@ -298,6 +309,7 @@ object NumericAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A nominal attribute. * @param name optional name * @param index optional index @@ -306,6 +318,7 @@ object NumericAttribute extends AttributeFactory { * defined. * @param values optional values. At most one of `numValues` and `values` can be defined. */ +@DeveloperApi class NominalAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -430,7 +443,11 @@ class NominalAttribute private[ml] ( } } -/** Factory methods for nominal attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for nominal attributes. + */ +@DeveloperApi object NominalAttribute extends AttributeFactory { /** The default nominal attribute. */ @@ -450,11 +467,13 @@ object NominalAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A binary attribute. * @param name optional name * @param index optional index * @param values optionla values. If set, its size must be 2. */ +@DeveloperApi class BinaryAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -526,7 +545,11 @@ class BinaryAttribute private[ml] ( } } -/** Factory methods for binary attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for binary attributes. + */ +@DeveloperApi object BinaryAttribute extends AttributeFactory { /** The default binary attribute. */ @@ -543,8 +566,10 @@ object BinaryAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * An unresolved attribute. */ +@DeveloperApi object UnresolvedAttribute extends Attribute { override def attrType: AttributeType = AttributeType.Unresolved diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 263d580fe2dd..85c097bc64a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils @@ -101,15 +102,20 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur var outputData = dataset var numColsOutput = 0 if (getRawPredictionCol != "") { - outputData = outputData.withColumn(getRawPredictionCol, - callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) numColsOutput += 1 } if (getPredictionCol != "") { val predUDF = if (getRawPredictionCol != "") { - callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol)) + udf(raw2prediction _).apply(col(getRawPredictionCol)) } else { - callUDF(predict _, DoubleType, col(getFeaturesCol)) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col(getFeaturesCol)) } outputData = outputData.withColumn(getPredictionCol, predUDF) numColsOutput += 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 7c961332bf5b..2dc1824964a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,14 +31,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class DecisionTreeClassifier(override val uid: String) extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { @@ -87,21 +86,23 @@ final class DecisionTreeClassifier(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } +@Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, override val rootNode: Node) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d504d84beb91..554e3b8e052b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -36,14 +36,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. */ -@AlphaComponent +@Experimental final class GBTClassifier(override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { @@ -142,8 +141,11 @@ final class GBTClassifier(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } +@Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ @@ -151,8 +153,7 @@ object GBTClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for classification. * It supports binary labels, as well as both continuous and categorical features. @@ -160,7 +161,7 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class GBTClassificationModel( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], @@ -209,7 +210,7 @@ private[ml] object GBTClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8694c96e4c5b..2e6eedd45ab0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import breeze.linalg.{norm => brzNorm, DenseVector => BDV} -import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import breeze.optimize.{CachedDiffFunction, DiffFunction} +import breeze.linalg.{DenseVector => BDV, norm => brzNorm} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable @@ -35,7 +35,6 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel -import org.apache.spark.{SparkException, Logging} /** * Params for logistic regression. @@ -45,12 +44,11 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas with HasThreshold /** - * :: AlphaComponent :: - * + * :: Experimental :: * Logistic regression. * Currently, this class only supports binary classification. */ -@AlphaComponent +@Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with Logging { @@ -76,7 +74,7 @@ class LogisticRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -92,7 +90,11 @@ class LogisticRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) - /** @group setParam */ + /** + * Whether to fit an intercept term. + * Default is true. + * @group setParam + * */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -218,14 +220,15 @@ class LogisticRegression(override val uid: String) new LogisticRegressionModel(uid, weights.compressed, intercept) } + + override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LogisticRegression]]. */ -@AlphaComponent +@Experimental class LogisticRegressionModel private[ml] ( override val uid: String, val weights: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1543f051ccd1..ea757c5e40c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,10 +21,10 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} @@ -37,11 +37,13 @@ import org.apache.spark.storage.StorageLevel */ private[ml] trait OneVsRestParams extends PredictorParams { + // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F type M <: ClassificationModel[F, M] type E <: Classifier[F, E, M] } + // scalastyle:on structural.type /** * param for the base binary classifier that we reduce multiclass classification into. @@ -54,8 +56,7 @@ private[ml] trait OneVsRestParams extends PredictorParams { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[OneVsRest]]. * This stores the models resulting from training k binary classifiers: one for each class. * Each example is scored against all k models, and the model with the highest score @@ -67,11 +68,11 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ -@AlphaComponent +@Experimental final class OneVsRestModel private[ml] ( override val uid: String, labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_,_]]) + val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { override def transformSchema(schema: StructType): StructType = { @@ -87,9 +88,9 @@ final class OneVsRestModel private[ml] ( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString - val init: () => Map[Int, Double] = () => {Map()} + val initUDF = udf { () => Map[Int, Double]() } val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) - val newDataset = dataset.withColumn(accColName, callUDF(init, mapType)) + val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -105,17 +106,16 @@ final class OneVsRestModel private[ml] ( // add temporary column to store intermediate scores and update val tmpColName = "mbc$tmp" + UUID.randomUUID().toString - val update: (Map[Int, Double], Vector) => Map[Int, Double] = - (predictions: Map[Int, Double], prediction: Vector) => { - predictions + ((index, prediction(1))) - } - val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) - val transformedDataset = model.transform(df).select(columns:_*) - val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) + val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => + predictions + ((index, prediction(1))) + } + val transformedDataset = model.transform(df).select(columns : _*) + val updatedDataset = transformedDataset + .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName) + updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) } if (handlePersistence) { @@ -123,13 +123,20 @@ final class OneVsRestModel private[ml] ( } // output the index of the classifier with highest confidence as prediction - val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => { + val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble } // output label and label metadata as prediction - val labelUdf = callUDF(label, DoubleType, col(accColName)) - aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + aggregatedDataset + .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) + .drop(accColName) + } + + override def copy(extra: ParamMap): OneVsRestModel = { + val copied = new OneVsRestModel( + uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) + copyValues(copied, extra) } } @@ -177,21 +184,19 @@ final class OneVsRest(override val uid: String) // create k columns, one for each binary classifier. val models = Range(0, numClasses).par.map { index => - - val label: Double => Double = (label: Double) => { + val labelUDF = udf { (label: Double) => if (label.toInt == index) 1.0 else 0.0 } // generate new label metadata for the binary problem. // TODO: use when ... otherwise after SPARK-7321 is merged - val labelUDF = callUDF(label, DoubleType, col($(labelCol))) val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta) + val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier classifier.fit(trainingDataset, classifier.labelCol -> labelColName) - }.toArray[ClassificationModel[_,_]] + }.toArray[ClassificationModel[_, _]] if (handlePersistence) { multiclassLabeled.unpersist() @@ -207,4 +212,12 @@ final class OneVsRest(override val uid: String) val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) copyValues(model) } + + override def copy(extra: ParamMap): OneVsRest = { + val copied = defaultCopy(extra).asInstanceOf[OneVsRest] + if (isDefined(classifier)) { + copied.setClassifier($(classifier).copy(extra)) + } + copied + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 330ae2938f4e..38e832372698 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -98,26 +98,34 @@ private[spark] abstract class ProbabilisticClassificationModel[ var outputData = dataset var numColsOutput = 0 if ($(rawPredictionCol).nonEmpty) { - outputData = outputData.withColumn(getRawPredictionCol, - callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) numColsOutput += 1 } if ($(probabilityCol).nonEmpty) { val probUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol))) + udf(raw2probability _).apply(col($(rawPredictionCol))) } else { - callUDF(predictProbability _, new VectorUDT, col($(featuresCol))) + val probabilityUDF = udf { (features: Any) => + predictProbability(features.asInstanceOf[FeaturesType]) + } + probabilityUDF(col($(featuresCol))) } outputData = outputData.withColumn($(probabilityCol), probUDF) numColsOutput += 1 } if ($(predictionCol).nonEmpty) { val predUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol))) + udf(raw2prediction _).apply(col($(rawPredictionCol))) } else if ($(probabilityCol).nonEmpty) { - callUDF(probability2prediction _, DoubleType, col($(probabilityCol))) + udf(probability2prediction _).apply(col($(probabilityCol))) } else { - callUDF(predict _, DoubleType, col($(featuresCol))) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col($(featuresCol))) } outputData = outputData.withColumn($(predictionCol), predUDF) numColsOutput += 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index a1de7919859e..d3c67494a31e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -33,14 +33,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for * classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class RandomForestClassifier(override val uid: String) extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { @@ -98,8 +97,11 @@ final class RandomForestClassifier(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } +@Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities @@ -110,15 +112,14 @@ object RandomForestClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ -@AlphaComponent +@Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel]) @@ -171,7 +172,7 @@ private[ml] object RandomForestClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index c1af09c9694b..4a82b77f0edc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.Evaluator +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -28,11 +27,10 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** - * :: AlphaComponent :: - * + * :: Experimental :: * Evaluator for binary classification, which expects two input columns: score and label. */ -@AlphaComponent +@Experimental class BinaryClassificationEvaluator(override val uid: String) extends Evaluator with HasRawPredictionCol with HasLabelCol { @@ -81,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String) metrics.unpersist() metric } + + override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala similarity index 86% rename from mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala rename to mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 5f2f8c94e9ff..e56c946a063e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -15,21 +15,21 @@ * limitations under the License. */ -package org.apache.spark.ml +package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for evaluators that compute metrics from predictions. */ -@AlphaComponent +@DeveloperApi abstract class Evaluator extends Params { /** - * Evaluates the output. + * Evaluates model output and returns a scalar metric (larger is better). * * @param dataset a dataset that contains labels/observations and predictions. * @param paramMap parameter map that specifies the input columns and output metrics @@ -46,7 +46,5 @@ abstract class Evaluator extends Params { */ def evaluate(dataset: DataFrame): Double - override def copy(extra: ParamMap): Evaluator = { - super.copy(extra).asInstanceOf[Evaluator] - } + override def copy(extra: ParamMap): Evaluator } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala new file mode 100644 index 000000000000..01c000b47514 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for regression, which expects two input columns: prediction and label. + */ +@Experimental +final class RegressionEvaluator(override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("regEval")) + + /** + * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * + * Because we will maximize evaluation value (ref: `CrossValidator`), + * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + * we take and output the negative of this metric. + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) + new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "rmse") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new RegressionMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "rmse" => + -metrics.rootMeanSquaredError + case "mse" => + -metrics.meanSquaredError + case "r2" => + metrics.r2 + case "mae" => + -metrics.meanAbsoluteError + } + metric + } + + override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 62f4a6343423..46314854d5e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ @@ -28,10 +28,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Binarize a column of continuous features given a threshold. */ -@AlphaComponent +@Experimental final class Binarizer(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { @@ -83,4 +83,6 @@ final class Binarizer(override val uid: String) val outputFields = inputFields :+ attr.toStructField() StructType(outputFields) } + + override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ac8dfb5632a7..67e4785bc355 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -31,10 +31,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * `Bucketizer` maps a column of continuous features to a column of feature buckets. */ -@AlphaComponent +@Experimental final class Bucketizer(override val uid: String) extends Model[Bucketizer] with HasInputCol with HasOutputCol { @@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String) SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } + + override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) } private[feature] object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala new file mode 100644 index 000000000000..228347635c92 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import edu.emory.mathcs.jtransforms.dct._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.BooleanParam +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.types.DataType + +/** + * :: Experimental :: + * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero + * padding is performed on the input vector. + * It returns a real vector of the same length representing the DCT. The return vector is scaled + * such that the transform matrix is unitary (aka scaled DCT-II). + * + * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]]. + */ +@Experimental +class DCT(override val uid: String) + extends UnaryTransformer[Vector, Vector, DCT] { + + def this() = this(Identifiable.randomUID("dct")) + + /** + * Indicates whether to perform the inverse DCT (true) or forward DCT (false). + * Default: false + * @group param + */ + def inverse: BooleanParam = new BooleanParam( + this, "inverse", "Set transformer to perform inverse DCT") + + /** @group setParam */ + def setInverse(value: Boolean): this.type = set(inverse, value) + + /** @group getParam */ + def getInverse: Boolean = $(inverse) + + setDefault(inverse -> false) + + override protected def createTransformFunc: Vector => Vector = { vec => + val result = vec.toArray + val jTransformer = new DoubleDCT_1D(result.length) + if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true) + Vectors.dense(result) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 8b32eee0e490..a359cb8f37ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -17,21 +17,21 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{ParamMap, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. */ -@AlphaComponent +@Experimental class ElementwiseProduct(override val uid: String) extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { @@ -41,7 +41,7 @@ class ElementwiseProduct(override val uid: String) * the vector to multiply with input vectors * @group param */ - val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product") + val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ def setScalingVec(value: Vector): this.type = set(scalingVec, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 30033ced68a0..319d23e46cef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,23 +17,32 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. */ -@AlphaComponent -class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] { +@Experimental +class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("hashingTF")) + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + /** * Number of features. Should be > 0. * (default = 2^18^) @@ -50,10 +59,21 @@ class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) - override protected def createTransformFunc: Iterable[_] => Vector = { + override def transform(dataset: DataFrame): DataFrame = { + val outputSchema = transformSchema(dataset.schema) val hashingTF = new feature.HashingTF($(numFeatures)) - hashingTF.transform + val t = udf { terms: Seq[_] => hashingTF.transform(terms) } + val metadata = outputSchema($(outputCol)).metadata + dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[ArrayType], + s"The input column must be ArrayType, but got $inputType.") + val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) + SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } - override protected def outputDataType: DataType = new VectorUDT() + override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 788c392050c2..ecde80810580 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol /** @group getParam */ def getMinDocFreq: Int = $(minDocFreq) - /** @group setParam */ - def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - /** * Validate and transform the input schema. */ @@ -58,10 +55,10 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol } /** - * :: AlphaComponent :: + * :: Experimental :: * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ -@AlphaComponent +@Experimental final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { def this() = this(Identifiable.randomUID("idf")) @@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + override def fit(dataset: DataFrame): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } @@ -82,13 +82,15 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDF = defaultCopy(extra) } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[IDF]]. */ -@AlphaComponent +@Experimental class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) @@ -109,4 +111,9 @@ class IDFModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDFModel = { + val copied = new IDFModel(uid, idfModel) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala new file mode 100644 index 000000000000..b30adf3df48d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]]. + */ +private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * lower bound after transformation, shared by all features + * Default: 0.0 + * @group param + */ + val min: DoubleParam = new DoubleParam(this, "min", + "lower bound of the output feature range") + + /** + * upper bound after transformation, shared by all features + * Default: 1.0 + * @group param + */ + val max: DoubleParam = new DoubleParam(this, "max", + "upper bound of the output feature range") + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def validateParams(): Unit = { + require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to a common range [min, max] linearly using column summary + * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + * feature E is calculated as, + * + * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min + * + * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min) + * Note that since zero values will probably be transformed to non-zero values, output of the + * transformer will be DenseVector even for sparse input. + */ +@Experimental +class MinMaxScaler(override val uid: String) + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + + def this() = this(Identifiable.randomUID("minMaxScal")) + + setDefault(min -> 0.0, max -> 1.0) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + override def fit(dataset: DataFrame): MinMaxScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[MinMaxScaler]]. + * + * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). + */ +@Experimental +class MinMaxScalerModel private[ml] ( + override val uid: String, + val originalMin: Vector, + val originalMax: Vector) + extends Model[MinMaxScalerModel] with MinMaxScalerParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + + override def transform(dataset: DataFrame): DataFrame = { + val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray + val minArray = originalMin.toArray + + val reScale = udf { (vector: Vector) => + val scale = $(max) - $(min) + + // 0 in sparse vector will probably be rescaled to non-zero + val values = vector.toArray + val size = values.size + var i = 0 + while (i < size) { + val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5 + values(i) = raw * scale + $(min) + i += 1 + } + Vectors.dense(values) + } + + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScalerModel = { + val copied = new MinMaxScalerModel(uid, originalMin, originalMax) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala new file mode 100644 index 000000000000..8de10eb51f92 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} + +/** + * :: Experimental :: + * A feature transformer that converts the input array of strings into an array of n-grams. Null + * values in the input array are ignored. + * It returns an array of n-grams where each n-gram is represented by a space-separated string of + * words. + * + * When the input is empty, an empty array is returned. + * When the input array length is less than n (number of elements per n-gram), no n-grams are + * returned. + */ +@Experimental +class NGram(override val uid: String) + extends UnaryTransformer[Seq[String], Seq[String], NGram] { + + def this() = this(Identifiable.randomUID("ngram")) + + /** + * Minimum n-gram length, >= 1. + * Default: 2, bigram features + * @group param + */ + val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)", + ParamValidators.gtEq(1)) + + /** @group setParam */ + def setN(value: Int): this.type = set(n, value) + + /** @group getParam */ + def getN: Int = $(n) + + setDefault(n -> 2) + + override protected def createTransformFunc: Seq[String] => Seq[String] = { + _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 3f689d1585cd..8282e5ffa17f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.ml.util.Identifiable @@ -26,10 +26,10 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Normalize a vector to have unit norm using the given p-norm. */ -@AlphaComponent +@Experimental class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { def this() = this(Identifiable.randomUID("normalizer")) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 1fb9b9ae7509..382594279564 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,93 +17,154 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} /** - * A one-hot encoder that maps a column of label indices to a column of binary vectors, with - * at most a single one-value. By default, the binary vector has an element for each category, so - * with 5 categories, an input value of 2.0 would map to an output vector of - * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the - * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - * linearly dependent because they sum up to one. + * :: Experimental :: + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * @see [[StringIndexer]] for converting categorical values into category indices */ -@AlphaComponent -class OneHotEncoder(override val uid: String) - extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol { +@Experimental +class OneHotEncoder(override val uid: String) extends Transformer + with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("oneHot")) /** - * Whether to include a component in the encoded vectors for the first category, defaults to true. + * Whether to drop the last category in the encoded vector (default: true) * @group param */ - final val includeFirst: BooleanParam = - new BooleanParam(this, "includeFirst", "include first category") - setDefault(includeFirst -> true) - - private var categories: Array[String] = _ + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) /** @group setParam */ - def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value) + def setDropLast(value: Boolean): this.type = set(dropLast, value) /** @group setParam */ - override def setInputCol(value: String): this.type = set(inputCol, value) + def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ - override def setOutputCol(value: String): this.type = set(outputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - val inputFields = schema.fields + val is = "_is_" + val inputColName = $(inputCol) val outputColName = $(outputCol) - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val inputColAttr = Attribute.fromStructField(schema($(inputCol))) - categories = inputColAttr match { + SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + val inputFields = schema.fields + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + val inputAttr = Attribute.fromStructField(schema(inputColName)) + val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => - nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray) - case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1")) + if (nominal.values.isDefined) { + nominal.values.map(_.map(v => inputColName + is + v)) + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values.map(_.map(v => inputColName + is + v)) + } else { + Some(Array.tabulate(2)(i => inputColName + is + i)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column $inputColName cannot be numeric.") case _ => - throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal") + None // optimistic about unknown attributes + } + + val filteredOutputAttrNames = outputAttrNames.map { names => + if ($(dropLast)) { + require(names.length > 1, + s"The input column $inputColName should have at least two distinct values.") + names.dropRight(1) + } else { + names + } } - val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray - val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues) - val outputFields = inputFields :+ attr.toStructField() + val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { + val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup($(outputCol), attrs) + } else { + new AttributeGroup($(outputCol)) + } + + val outputFields = inputFields :+ outputAttrGroup.toStructField() StructType(outputFields) } - protected override def createTransformFunc(): (Double) => Vector = { - val first = $(includeFirst) - val vecLen = if (first) categories.length else categories.length - 1 + override def transform(dataset: DataFrame): DataFrame = { + // schema transformation + val is = "_is_" + val inputColName = $(inputCol) + val outputColName = $(outputCol) + val shouldDropLast = $(dropLast) + var outputAttrGroup = AttributeGroup.fromStructField( + transformSchema(dataset.schema)(outputColName)) + if (outputAttrGroup.size < 0) { + // If the number of attributes is unknown, we check the values from the input column. + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + .aggregate(0.0)( + (m, x) => { + assert(x >=0.0 && x == x.toInt, + s"Values from column $inputColName must be indices, but got $x.") + math.max(m, x) + }, + (m0, m1) => { + math.max(m0, m1) + } + ).toInt + 1 + val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames + val outputAttrs: Array[Attribute] = + filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) + outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + } + val metadata = outputAttrGroup.toMetadata() + + // data transformation + val size = outputAttrGroup.size val oneValue = Array(1.0) val emptyValues = Array[Double]() val emptyIndices = Array[Int]() - label: Double => { - val values = if (first || label != 0.0) oneValue else emptyValues - val indices = if (first) { - Array(label.toInt) - } else if (label != 0.0) { - Array(label.toInt - 1) + val encode = udf { label: Double => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) } else { - emptyIndices + Vectors.sparse(size, emptyIndices, emptyValues) } - Vectors.sparse(vecLen, indices, values) } + + dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) } - /** - * Returns the data type of the output column. - */ - protected def outputDataType: DataType = new VectorUDT + override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala new file mode 100644 index 000000000000..2d3bb680cf30 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[PCA]] and [[PCAModel]]. + */ +private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol { + + /** + * The number of principal components. + * @group param + */ + final val k: IntParam = new IntParam(this, "k", "the number of principal components") + + /** @group getParam */ + def getK: Int = $(k) + +} + +/** + * :: Experimental :: + * PCA trains a model to project vectors to a low-dimensional space using PCA. + */ +@Experimental +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { + + def this() = this(Identifiable.randomUID("pca")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setK(value: Int): this.type = set(k, value) + + /** + * Computes a [[PCAModel]] that contains the principal components of the input vectors. + */ + override def fit(dataset: DataFrame): PCAModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} + val pca = new feature.PCA(k = $(k)) + val pcaModel = pca.fit(input) + copyValues(new PCAModel(uid, pcaModel).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCA = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[PCA]]. + */ +@Experimental +class PCAModel private[ml] ( + override val uid: String, + pcaModel: feature.PCAModel) + extends Model[PCAModel] with PCAParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a vector by computed Principal Components. + * NOTE: Vectors to be transformed must be the same length + * as the source vectors given to [[PCA.fit()]]. + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val pcaOp = udf { pcaModel.transform _ } + dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCAModel = { + val copied = new PCAModel(uid, pcaModel) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 41564410e496..d85e468562d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,22 +19,22 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an * expansion of a product of sums expresses it as a sum of products by using the fact that * multiplication distributes over addition". Take a 2-variable feature vector as an example: * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. */ -@AlphaComponent +@Experimental class PolynomialExpansion(override val uid: String) extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { @@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) } /** @@ -75,7 +77,7 @@ class PolynomialExpansion(override val uid: String) * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the * current index and increment it properly for sparse input. */ -object PolynomialExpansion { +private[feature] object PolynomialExpansion { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 5ccda15d872e..ca3c1cfb56b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -35,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** * Centers the data with mean before scaling. - * It will build a dense output, so this does not work on sparse input + * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") - + /** * Scales the data to unit standard deviation. * Default: true @@ -51,11 +51,11 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with } /** - * :: AlphaComponent :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. */ -@AlphaComponent +@Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] with StandardScalerParams { @@ -68,13 +68,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + /** @group setParam */ def setWithMean(value: Boolean): this.type = set(withMean, value) - + /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - + override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } @@ -92,13 +92,15 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StandardScaler]]. */ -@AlphaComponent +@Experimental class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) @@ -125,4 +127,9 @@ class StandardScalerModel private[ml] ( val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScalerModel = { + val copied = new StandardScalerModel(uid, scaler) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3f79b67309f0..bf7be363b822 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -52,13 +52,13 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha } /** - * :: AlphaComponent :: + * :: Experimental :: * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ -@AlphaComponent +@Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] with StringIndexerBase { @@ -83,13 +83,18 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StringIndexer]]. + * NOTE: During transformation, if the input column does not exist, + * [[StringIndexerModel.transform]] would return the input dataset unmodified. + * This is a temporary fix for the case when target labels do not exist during prediction. */ -@AlphaComponent +@Experimental class StringIndexerModel private[ml] ( override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { @@ -112,6 +117,12 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + if (!dataset.schema.fieldNames.contains($(inputCol))) { + logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + + "Skip StringIndexerModel.") + return dataset + } + val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) @@ -128,6 +139,16 @@ class StringIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + if (schema.fieldNames.contains($(inputCol))) { + validateAndTransformSchema(schema) + } else { + // If the input column does not exist during transformation, we skip StringIndexerModel. + schema + } + } + + override def copy(extra: ParamMap): StringIndexerModel = { + val copied = new StringIndexerModel(uid, labels) + copyValues(copied, extra) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 36d9e17eca41..5f9f57a2ebcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,17 +17,19 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** - * :: AlphaComponent :: + * :: Experimental :: * A tokenizer that converts the input string to lowercase and then splits it by white spaces. + * + * @see [[RegexTokenizer]] */ -@AlphaComponent +@Experimental class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { def this() = this(Identifiable.randomUID("tok")) @@ -41,16 +43,18 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } /** - * :: AlphaComponent :: - * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) - * or using it to split the text (set matching to false). Optional parameters also allow filtering - * tokens using a minimal length. + * :: Experimental :: + * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split + * the text (default) or repeatedly matching the regex (if `gaps` is true). + * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. */ -@AlphaComponent +@Experimental class RegexTokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], RegexTokenizer] { @@ -61,7 +65,7 @@ class RegexTokenizer(override val uid: String) * Default: 1, to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)", + val minTokenLength: IntParam = new IntParam(this, "minTokenLength", "minimum token length (>= 0)", ParamValidators.gtEq(0)) /** @group setParam */ @@ -71,8 +75,8 @@ class RegexTokenizer(override val uid: String) def getMinTokenLength: Int = $(minTokenLength) /** - * Indicates whether regex splits on gaps (true) or matching tokens (false). - * Default: false + * Indicates whether regex splits on gaps (true) or matches tokens (false). + * Default: true * @group param */ val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") @@ -84,8 +88,8 @@ class RegexTokenizer(override val uid: String) def getGaps: Boolean = $(gaps) /** - * Regex pattern used by tokenizer. - * Default: `"\\p{L}+|[^\\p{L}\\s]+"` + * Regex pattern used to match delimiters if [[gaps]] is true or tokens if [[gaps]] is false. + * Default: `"\\s+"` * @group param */ val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") @@ -96,7 +100,7 @@ class RegexTokenizer(override val uid: String) /** @group getParam */ def getPattern: String = $(pattern) - setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") override protected def createTransformFunc: String => Seq[String] = { str => val re = $(pattern).r @@ -110,4 +114,6 @@ class RegexTokenizer(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 1c0009476908..9f83c2ee1617 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,8 +20,10 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} @@ -30,14 +32,14 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: Experimental :: * A feature transformer that merges multiple columns into a vector column. */ -@AlphaComponent +@Experimental class VectorAssembler(override val uid: String) extends Transformer with HasInputCols with HasOutputCol { - def this() = this(Identifiable.randomUID("va")) + def this() = this(Identifiable.randomUID("vecAssembler")) /** @group setParam */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -46,19 +48,59 @@ class VectorAssembler(override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + // Schema transformation. + val schema = dataset.schema + lazy val first = dataset.first() + val attrs = $(inputCols).flatMap { c => + val field = schema(c) + val index = schema.fieldIndex(c) + field.dataType match { + case DoubleType => + val attr = Attribute.fromStructField(field) + // If the input column doesn't have ML attribute, assume numeric. + if (attr == UnresolvedAttribute) { + Some(NumericAttribute.defaultAttr.withName(c)) + } else { + Some(attr.withName(c)) + } + case _: NumericType | BooleanType => + // If the input column type is a compatible scalar type, assume numeric. + Some(NumericAttribute.defaultAttr.withName(c)) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(field) + if (group.attributes.isDefined) { + // If attributes are defined, copy them with updated names. + group.attributes.get.map { attr => + if (attr.name.isDefined) { + // TODO: Define a rigorous naming scheme. + attr.withName(c + "_" + attr.name.get) + } else { + attr + } + } + } else { + // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes + // from metadata, check the first row. + val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) + Array.fill(numAttrs)(NumericAttribute.defaultAttr) + } + } + } + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + + // Data transformation. val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } - val schema = dataset.schema - val inputColNames = $(inputCols) - val args = inputColNames.map { c => + val args = $(inputCols).map { c => schema(c).dataType match { case DoubleType => dataset(c) case _: VectorUDT => dataset(c) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol))) + + dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { @@ -76,10 +118,11 @@ class VectorAssembler(override val uid: String) } StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) } + + override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) } -@AlphaComponent -object VectorAssembler { +private object VectorAssembler { private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 6d1d0524e59e..c73bdccdef5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -17,15 +17,20 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import java.lang.{Double => JDouble, Integer => JInt} +import java.util.{Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.callUDF +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -51,8 +56,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Class for indexing categorical feature columns in a dataset of [[Vector]]. * * This has 2 usage modes: @@ -86,7 +90,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Add warning if a categorical feature has only 1 category. * - Add option for allowing unknown categories. */ -@AlphaComponent +@Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] with VectorIndexerParams { @@ -127,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod SchemaUtils.checkColumnType(schema, $(inputCol), dataType) SchemaUtils.appendColumn(schema, $(outputCol), dataType) } + + override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } private object VectorIndexer { @@ -225,8 +231,7 @@ private object VectorIndexer { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Transform categorical features to use 0-based indices instead of their original values. * - Categorical features are mapped to indices. * - Continuous features (columns) are left unchanged. @@ -241,13 +246,18 @@ private object VectorIndexer { * Values are maps from original features values to 0-based category indices. * If a feature is not in this map, it is treated as continuous. */ -@AlphaComponent +@Experimental class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) extends Model[VectorIndexerModel] with VectorIndexerParams { + /** Java-friendly version of [[categoryMaps]] */ + def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { + categoryMaps.mapValues(_.asJava).asJava.asInstanceOf[JMap[JInt, JMap[JDouble, JInt]]] + } + /** * Pre-computed feature attributes, with some missing info. * In transform(), set attribute name and other info, if available. @@ -329,7 +339,8 @@ class VectorIndexerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) - val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol))) + val transformUDF = udf { (vector: Vector) => transformFunc(vector) } + val newCol = transformUDF(dataset($(inputCol))) dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) } @@ -391,4 +402,9 @@ class VectorIndexerModel private[ml] ( val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) newAttributeGroup.toStructField() } + + override def copy(extra: ParamMap): VectorIndexerModel = { + val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 8ace8c53bb66..6ea659095630 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -37,6 +37,7 @@ private[feature] trait Word2VecBase extends Params /** * The dimension of the code that you want to transform from words. + * @group param */ final val vectorSize = new IntParam( this, "vectorSize", "the dimension of codes after transforming from words") @@ -47,6 +48,7 @@ private[feature] trait Word2VecBase extends Params /** * Number of partitions for sentences of words. + * @group param */ final val numPartitions = new IntParam( this, "numPartitions", "number of partitions for sentences of words") @@ -58,6 +60,7 @@ private[feature] trait Word2VecBase extends Params /** * The minimum number of times a token must appear to be included in the word2vec model's * vocabulary. + * @group param */ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + "appear to be included in the word2vec model's vocabulary") @@ -68,7 +71,6 @@ private[feature] trait Word2VecBase extends Params setDefault(stepSize -> 0.025) setDefault(maxIter -> 1) - setDefault(seed -> 42L) /** * Validate and transform the input schema. @@ -80,11 +82,11 @@ private[feature] trait Word2VecBase extends Params } /** - * :: AlphaComponent :: + * :: Experimental :: * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further * natural language processing or machine learning process. */ -@AlphaComponent +@Experimental final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { def this() = this(Identifiable.randomUID("w2v")) @@ -130,13 +132,15 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[Word2Vec]]. */ -@AlphaComponent +@Experimental class Word2VecModel private[ml] ( override val uid: String, wordVectors: feature.Word2VecModel) @@ -178,4 +182,9 @@ class Word2VecModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2VecModel = { + val copied = new Word2VecModel(uid, wordVectors) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java index 00d9c802e930..87f4223964ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -16,10 +16,10 @@ */ /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. */ -@AlphaComponent +@Experimental package org.apache.spark.ml; -import org.apache.spark.annotation.AlphaComponent; +import org.apache.spark.annotation.Experimental; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index ac75e9de1a8f..c589d06d9f7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -18,7 +18,7 @@ package org.apache.spark /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. * * @groupname param Parameters diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 247e08be1bb1..50c0d855066f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,11 +24,11 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A param with self-contained documentation and optionally default value. Primitive-typed param * should use the specialized versions, which are more friendly to Java users. * @@ -39,7 +39,7 @@ import org.apache.spark.ml.util.Identifiable * See [[ParamValidators]] for factory methods for common validation functions. * @tparam T param value type */ -@AlphaComponent +@DeveloperApi class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean) extends Serializable { @@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - /** - * Creates a param pair with the given value (for Java). - */ + /** Creates a param pair with the given value (for Java). */ def w(value: T): ParamPair[T] = this -> value - /** - * Creates a param pair with the given value (for Scala). - */ + /** Creates a param pair with the given value (for Scala). */ def ->(value: T): ParamPair[T] = ParamPair(this, value) override final def toString: String = s"${parent}__$name" @@ -92,9 +88,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } /** + * :: DeveloperApi :: * Factory methods for common validation functions for [[Param.isValid]]. * The numerical methods only support Int, Long, Float, and Double. */ +@DeveloperApi object ParamValidators { /** (private[param]) Default validation always return true */ @@ -172,7 +170,11 @@ object ParamValidators { // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... -/** Specialized version of [[Param[Double]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Double]]] for Java. + */ +@DeveloperApi class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean) extends Param[Double](parent, name, doc, isValid) { @@ -184,10 +186,15 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) } -/** Specialized version of [[Param[Int]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Int]]] for Java. + */ +@DeveloperApi class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean) extends Param[Int](parent, name, doc, isValid) { @@ -199,10 +206,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) } -/** Specialized version of [[Param[Float]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Float]]] for Java. + */ +@DeveloperApi class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean) extends Param[Float](parent, name, doc, isValid) { @@ -214,10 +226,15 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) } -/** Specialized version of [[Param[Long]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Long]]] for Java. + */ +@DeveloperApi class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean) extends Param[Long](parent, name, doc, isValid) { @@ -229,47 +246,60 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) } -/** Specialized version of [[Param[Boolean]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Boolean]]] for Java. + */ +@DeveloperApi class BooleanParam(parent: String, name: String, doc: String) // No need for isValid extends Param[Boolean](parent, name, doc) { def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** Creates a param pair with the given value (for Java). */ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } -/** Specialized version of [[Param[Array[String]]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[String]]]] for Java. + */ +@DeveloperApi class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean) extends Param[Array[String]](parent, name, doc, isValid) { def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) - override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value) - /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) } -/** Specialized version of [[Param[Array[Double]]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[Double]]]] for Java. + */ +@DeveloperApi class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean) extends Param[Array[Double]](parent, name, doc, isValid) { def this(parent: Params, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) - override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value) - /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ - def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray) + def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = + w(value.asScala.map(_.asInstanceOf[Double]).toArray) } /** - * A param amd its value. + * :: Experimental :: + * A param and its value. */ +@Experimental case class ParamPair[T](param: Param[T], value: T) { // This is *the* place Param.validate is called. Whenever a parameter is specified, we should // always construct a ParamPair so that validate is called. @@ -277,11 +307,11 @@ case class ParamPair[T](param: Param[T], value: T) { } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Trait for components that take parameters. This also provides an internal param map to store * parameter values attached to the instance. */ -@AlphaComponent +@DeveloperApi trait Params extends Identifiable with Serializable { /** @@ -301,19 +331,6 @@ trait Params extends Identifiable with Serializable { .map(m => m.invoke(this).asInstanceOf[Param[_]]) } - /** - * Validates parameter values stored internally plus the input parameter map. - * Raises an exception if any parameter is invalid. - * - * This only needs to check for interactions between parameters. - * Parameter value checks which do not depend on other parameters are handled by - * [[Param.validate()]]. This method does not handle input/output column parameters; - * those are checked during schema validation. - */ - def validateParams(paramMap: ParamMap): Unit = { - copy(paramMap).validateParams() - } - /** * Validates parameter values stored internally. * Raise an exception if any parameter value is invalid. @@ -438,19 +455,18 @@ trait Params extends Identifiable with Serializable { * @param value the default value */ protected final def setDefault[T](param: Param[T], value: T): this.type = { - defaultParamMap.put(param, value) + defaultParamMap.put(param -> value) this } /** * Sets default values for a list of params. * - * Note: Java developers should use the single-parameter [[setDefault()]]. - * Annotating this with varargs causes compilation failures. See SPARK-7498. * @param paramPairs a list of param pairs that specify params and their default values to set * respectively. Make sure that the params are initialized before this method * gets called. */ + @varargs protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => setDefault(p.param.asInstanceOf[Param[Any]], p.value) @@ -476,23 +492,29 @@ trait Params extends Identifiable with Serializable { /** * Creates a copy of this instance with the same UID and some extra params. - * The default implementation tries to create a new instance with the same UID. + * Subclasses should implement this method and set the return type properly. + * + * @see [[defaultCopy()]] + */ + def copy(extra: ParamMap): Params + + /** + * Default implementation of copy with extra params. + * It tries to create a new instance with the same UID. * Then it copies the embedded and extra parameters over and returns the new instance. - * Subclasses should override this method if the default approach is not sufficient. */ - def copy(extra: ParamMap): Params = { + protected final def defaultCopy[T <: Params](extra: ParamMap): T = { val that = this.getClass.getConstructor(classOf[String]).newInstance(uid) - copyValues(that, extra) - that + copyValues(that, extra).asInstanceOf[T] } /** * Extracts the embedded default param values and user-supplied values, and then merges them with * extra values from input into a flat param map, where the latter value is used if there exist - * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. + * conflicts, i.e., with ordering: default param values < user-supplied values < extra. */ - final def extractParamMap(extraParamMap: ParamMap): ParamMap = { - defaultParamMap ++ paramMap ++ extraParamMap + final def extractParamMap(extra: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extra } /** @@ -531,18 +553,20 @@ trait Params extends Identifiable with Serializable { } /** + * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. * Java developers who need to extend [[Params]] should use this class instead. * If you need to extend a abstract class which already extends [[Params]], then that abstract * class should be Java-friendly as well. */ +@DeveloperApi abstract class JavaParams extends Params /** - * :: AlphaComponent :: + * :: Experimental :: * A param to value map. */ -@AlphaComponent +@Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -560,7 +584,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Puts a (param, value) pair (overwrites if the input param exists). */ - def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value)) + def put[T](param: Param[T], value: T): this.type = put(param -> value) /** * Puts a list of param pairs (overwrites if the input params exists). @@ -663,6 +687,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) def size: Int = map.size } +@Experimental object ParamMap { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 5085b798daa1..b0a6af171c01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen { val params = Seq( ParamDesc[Double]("regParam", "regularization parameter (>= 0)", isValid = "ParamValidators.gtEq(0)"), - ParamDesc[Int]("maxIter", "max number of iterations (>= 0)", + ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)", isValid = "ParamValidators.gtEq(0)"), ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), @@ -49,11 +49,14 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), - ParamDesc[String]("outputCol", "output column name"), + ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Boolean]("standardization", "whether to standardize the training features" + + " prior to fitting the model sequence. Note that the coefficients of models are" + + " always returned on the original scale.", Some("true")), + ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", isValid = "ParamValidators.inRange(0, 1)"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 7525d3700737..bbe08939b6d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params { private[ml] trait HasMaxIter extends Params { /** - * Param for max number of iterations (>= 0). + * Param for maximum number of iterations (>= 0). * @group param */ - final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) + final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getMaxIter: Int = $(maxIter) @@ -185,7 +185,7 @@ private[ml] trait HasInputCols extends Params { } /** - * (private[ml]) Trait for shared param outputCol. + * (private[ml]) Trait for shared param outputCol (default: uid + "__output"). */ private[ml] trait HasOutputCol extends Params { @@ -195,6 +195,8 @@ private[ml] trait HasOutputCol extends Params { */ final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + setDefault(outputCol, uid + "__output") + /** @group getParam */ final def getOutputCol: String = $(outputCol) } @@ -232,7 +234,24 @@ private[ml] trait HasFitIntercept extends Params { } /** - * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()). + * (private[ml]) Trait for shared param standardization (default: true). + */ +private[ml] trait HasStandardization extends Params { + + /** + * Param for whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.. + * @group param + */ + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.") + + setDefault(standardization, true) + + /** @group getParam */ + final def getStandardization: Boolean = $(standardization) +} + +/** + * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). */ private[ml] trait HasSeed extends Params { @@ -242,7 +261,7 @@ private[ml] trait HasSeed extends Params { */ final val seed: LongParam = new LongParam(this, "seed", "random seed") - setDefault(seed, Utils.random.nextLong()) + setDefault(seed, this.getClass.getName.hashCode.toLong) /** @group getParam */ final def getSeed: Long = $(seed) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 45c57b50da70..2e44cd4cc6a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,25 +31,50 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.netlib.util.intW import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom +/** + * Common params for ALS and ALSModel. + */ +private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { + /** + * Param for the column name for user ids. + * Default: "user" + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids") + + /** @group getParam */ + def getUserCol: String = $(userCol) + + /** + * Param for the column name for item ids. + * Default: "item" + * @group param + */ + val itemCol = new Param[String](this, "itemCol", "column name for item ids") + + /** @group getParam */ + def getItemCol: String = $(itemCol) +} + /** * Common params for ALS. */ -private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam +private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval with HasSeed { /** @@ -105,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** @group getParam */ def getAlpha: Double = $(alpha) - /** - * Param for the column name for user ids. - * Default: "user" - * @group param - */ - val userCol = new Param[String](this, "userCol", "column name for user ids") - - /** @group getParam */ - def getUserCol: String = $(userCol) - - /** - * Param for the column name for item ids. - * Default: "item" - * @group param - */ - val itemCol = new Param[String](this, "itemCol", "column name for item ids") - - /** @group getParam */ - def getItemCol: String = $(itemCol) - /** * Param for the column name for ratings. * Default: "rating" @@ -148,7 +153,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", - ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L) + ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10) /** * Validates and transforms the input schema. @@ -156,58 +161,71 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - require(schema($(userCol)).dataType == IntegerType) - require(schema($(itemCol)).dataType== IntegerType) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType require(ratingType == FloatType || ratingType == DoubleType) - val predictionColName = $(predictionCol) - require(!schema.fieldNames.contains(predictionColName), - s"Prediction column $predictionColName already exists.") - val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false) - StructType(newFields) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** + * :: Experimental :: * Model fitted by ALS. + * + * @param rank rank of the matrix factorization model + * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features` + * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ +@Experimental class ALSModel private[ml] ( override val uid: String, - k: Int, - userFactors: RDD[(Int, Array[Float])], - itemFactors: RDD[(Int, Array[Float])]) - extends Model[ALSModel] with ALSParams { + val rank: Int, + @transient val userFactors: DataFrame, + @transient val itemFactors: DataFrame) + extends Model[ALSModel] with ALSModelParams { + + /** @group setParam */ + def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ + def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) override def transform(dataset: DataFrame): DataFrame = { - import dataset.sqlContext.implicits._ - val users = userFactors.toDF("id", "features") - val items = itemFactors.toDF("id", "features") - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { - blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } dataset - .join(users, dataset($(userCol)) === users("id"), "left") - .join(items, dataset($(itemCol)) === items("id"), "left") - .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol))) + .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") + .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .select(dataset("*"), + predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) + } + + override def copy(extra: ParamMap): ALSModel = { + val copied = new ALSModel(uid, rank, userFactors, itemFactors) + copyValues(copied, extra) } } /** + * :: Experimental :: * Alternating Least Squares (ALS) matrix factorization. * * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, @@ -236,6 +254,7 @@ class ALSModel private[ml] ( * indicated user * preferences rather than explicit ratings given to items. */ +@Experimental class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating @@ -295,6 +314,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def fit(dataset: DataFrame): ALSModel = { + import dataset.sqlContext.implicits._ val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), col($(ratingCol)).cast(FloatType)) @@ -306,13 +326,17 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), alpha = $(alpha), nonnegative = $(nonnegative), checkpointInterval = $(checkpointInterval), seed = $(seed)) - val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this) + val userDF = userFactors.toDF("id", "features") + val itemDF = itemFactors.toDF("id", "features") + val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) copyValues(model) } override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): ALS = defaultCopy(extra) } /** @@ -326,7 +350,11 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { @DeveloperApi object ALS extends Logging { - /** Rating class for better code readability. */ + /** + * :: DeveloperApi :: + * Rating class for better code readability. + */ + @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) /** Trait for least squares solvers applied to the normal equation. */ @@ -487,8 +515,10 @@ object ALS extends Logging { } /** + * :: DeveloperApi :: * Implementation of the ALS algorithm. */ + @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore ratings: RDD[Rating[ID]], rank: Int = 10, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index e67df21b2e4a..be1f8063d41d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,13 +31,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class DecisionTreeRegressor(override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeParams with TreeRegressorParams { @@ -77,21 +76,23 @@ final class DecisionTreeRegressor(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra) } +@Experimental object DecisionTreeRegressor { /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. * It supports both continuous and categorical features. * @param rootNode Root of the decision tree */ -@AlphaComponent +@Experimental final class DecisionTreeRegressionModel private[ml] ( override val uid: String, override val rootNode: Node) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 4249ff5c1ebc..47c110d027d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -35,13 +35,12 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class GBTRegressor(override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTParams with TreeRegressorParams with Logging { @@ -132,8 +131,11 @@ final class GBTRegressor(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) } +@Experimental object GBTRegressor { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @@ -141,7 +143,7 @@ object GBTRegressor { } /** - * :: AlphaComponent :: + * :: Experimental :: * * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for regression. @@ -149,7 +151,7 @@ object GBTRegressor { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class GBTRegressionModel( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], @@ -170,8 +172,7 @@ final class GBTRegressionModel( // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) - if (prediction > 0.0) 1.0 else 0.0 + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } override def copy(extra: ParamMap): GBTRegressionModel = { @@ -198,7 +199,7 @@ private[ml] object GBTRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 3ebb78f79201..1b1d7299fb49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -23,10 +23,10 @@ import breeze.linalg.{DenseVector => BDV, norm => brzNorm} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -41,11 +41,11 @@ import org.apache.spark.util.StatCounter * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams - with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol + with HasFitIntercept /** - * :: AlphaComponent :: - * + * :: Experimental :: * Linear regression. * * The learning objective is to minimize the squared error, with regularization. @@ -58,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L1 (Lasso) * - L2 + L1 (elastic net) */ -@AlphaComponent +@Experimental class LinearRegression(override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with Logging { @@ -73,6 +73,14 @@ class LinearRegression(override val uid: String) def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + /** * Set the ElasticNet mixing parameter. * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. @@ -84,7 +92,7 @@ class LinearRegression(override val uid: String) setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -124,6 +132,7 @@ class LinearRegression(override val uid: String) val numFeatures = summarizer.mean.size val yMean = statCounter.mean val yStd = math.sqrt(statCounter.variance) + // look at glmnet5.m L761 maaaybe that has info // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -143,7 +152,7 @@ class LinearRegression(override val uid: String) val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam - val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), featuresStd, featuresMean, effectiveL2RegParam) val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { @@ -181,20 +190,21 @@ class LinearRegression(override val uid: String) // The intercept in R's GLMNET is computed using closed form after the coefficients are // converged. See the following discussion for detail. // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) + val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0 if (handlePersistence) instances.unpersist() // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. copyValues(new LinearRegressionModel(uid, weights.compressed, intercept)) } + + override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LinearRegression]]. */ -@AlphaComponent +@Experimental class LinearRegressionModel private[ml] ( override val uid: String, val weights: Vector, @@ -234,6 +244,7 @@ class LinearRegressionModel private[ml] ( * See this discussion for detail. * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet * + * When training with intercept enabled, * The objective function in the scaled space is given by * {{{ * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, @@ -241,6 +252,10 @@ class LinearRegressionModel private[ml] ( * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i, * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label. * + * If we fitting the intercept disabled (that is forced through 0.0), + * we can use the same equation except we set \bar{y} and \bar{x_i} to 0 instead + * of the respective means. + * * This can be rewritten as * {{{ * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} @@ -255,6 +270,7 @@ class LinearRegressionModel private[ml] ( * \sum_i w_i^\prime x_i - y / \hat{y} + offset * }}} * + * * Note that the effective weights and offset don't depend on training dataset, * so they can be precomputed. * @@ -301,6 +317,7 @@ private class LeastSquaresAggregator( weights: Vector, labelStd: Double, labelMean: Double, + fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { @@ -321,9 +338,9 @@ private class LeastSquaresAggregator( } i += 1 } - (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) } - + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) private val gradientSumArray = Array.ofDim[Double](dim) @@ -404,6 +421,7 @@ private class LeastSquaresCostFun( data: RDD[(Double, Vector)], labelStd: Double, labelMean: Double, + fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { @@ -412,7 +430,7 @@ private class LeastSquaresCostFun( val w = Vectors.fromBreeze(weights) val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, - labelMean, featuresStd, featuresMean))( + labelMean, fitIntercept, featuresStd, featuresMean))( seqOp = (c, v) => (c, v) match { case (aggregator, (label, features)) => aggregator.add(label, features) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82437aa8de29..21c59061a02f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -31,12 +31,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent +@Experimental final class RandomForestRegressor(override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestParams with TreeRegressorParams { @@ -87,8 +86,11 @@ final class RandomForestRegressor(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) } +@Experimental object RandomForestRegressor { /** Accessor for supported impurity settings: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -99,13 +101,12 @@ object RandomForestRegressor { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. * @param _trees Decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class RandomForestRegressionModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel]) @@ -153,7 +154,7 @@ private[ml] object RandomForestRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } new RandomForestRegressionModel(parent.uid, newTrees) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d2dec0c76cb1..4242154be14c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} - /** + * :: DeveloperApi :: * Decision tree node interface. */ +@DeveloperApi sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree @@ -89,10 +91,12 @@ private[ml] object Node { } /** + * :: DeveloperApi :: * Decision tree leaf node. * @param prediction Prediction this node makes * @param impurity Impurity measure at this node (for training data) */ +@DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double) extends Node { @@ -118,6 +122,7 @@ final class LeafNode private[ml] ( } /** + * :: DeveloperApi :: * Internal Decision Tree node. * @param prediction Prediction this node would make if it were a leaf node * @param impurity Impurity measure at this node (for training data) @@ -127,6 +132,7 @@ final class LeafNode private[ml] ( * @param rightChild Right-hand child node * @param split Information about the test used to split to the left or right child. */ +@DeveloperApi final class InternalNode private[ml] ( override val prediction: Double, override val impurity: Double, @@ -153,9 +159,9 @@ final class InternalNode private[ml] ( override private[tree] def subtreeToString(indentFactor: Int = 0): String = { val prefix: String = " " * indentFactor - prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" + leftChild.subtreeToString(indentFactor + 1) + - prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" + rightChild.subtreeToString(indentFactor + 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 90f1d052764d..7acdeeee72d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} /** + * :: DeveloperApi :: * Interface for a "Split," which specifies a test made at a decision tree node * to choose the left or right path. */ +@DeveloperApi sealed trait Split extends Serializable { /** Index of feature which this split tests */ @@ -52,12 +55,14 @@ private[tree] object Split { } /** + * :: DeveloperApi :: * Split which tests a categorical feature. * @param featureIndex Index of the feature to test * @param _leftCategories If the feature value is in this set of categories, then the split goes * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ +@DeveloperApi final class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], @@ -125,11 +130,13 @@ final class CategoricalSplit private[ml] ( } /** + * :: DeveloperApi :: * Split which tests a continuous feature. * @param featureIndex Index of the feature to test * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. */ +@DeveloperApi final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) extends Split { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 816fcedf2efb..a0c5238d966b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} @@ -26,12 +25,10 @@ import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldG import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} /** - * :: DeveloperApi :: * Parameters for Decision Tree-based algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait DecisionTreeParams extends PredictorParams { /** @@ -265,12 +262,10 @@ private[ml] object TreeRegressorParams { } /** - * :: DeveloperApi :: * Parameters for Decision Tree-based ensemble algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @@ -307,12 +302,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { } /** - * :: DeveloperApi :: * Parameters for Random Forest algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait RandomForestParams extends TreeEnsembleParams { /** @@ -377,12 +370,10 @@ private[ml] object RandomForestParams { } /** - * :: DeveloperApi :: * Parameters for Gradient-Boosted Tree algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5c6ff2dda360..e2444ab65b43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -20,8 +20,9 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.util.MLUtils @@ -78,10 +79,10 @@ private[ml] trait CrossValidatorParams extends Params { } /** - * :: AlphaComponent :: + * :: Experimental :: * K-fold cross validation. */ -@AlphaComponent +@Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { @@ -101,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def validateParams(paramMap: ParamMap): Unit = { - getEstimatorParamMaps.foreach { eMap => - getEstimator.validateParams(eMap ++ paramMap) - } - } - override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -140,26 +135,46 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } + + override def copy(extra: ParamMap): CrossValidator = { + val copied = defaultCopy(extra).asInstanceOf[CrossValidator] + if (copied.isDefined(estimator)) { + copied.setEstimator(copied.getEstimator.copy(extra)) + } + if (copied.isDefined(evaluator)) { + copied.setEvaluator(copied.getEvaluator.copy(extra)) + } + copied + } } /** - * :: AlphaComponent :: + * :: Experimental :: * Model from k-fold cross validation. */ -@AlphaComponent +@Experimental class CrossValidatorModel private[ml] ( override val uid: String, - val bestModel: Model[_]) + val bestModel: Model[_], + val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def validateParams(paramMap: ParamMap): Unit = { - bestModel.validateParams(paramMap) + override def validateParams(): Unit = { + bestModel.validateParams() } override def transform(dataset: DataFrame): DataFrame = { @@ -170,4 +185,12 @@ class CrossValidatorModel private[ml] ( override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + + override def copy(extra: ParamMap): CrossValidatorModel = { + val copied = new CrossValidatorModel( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + avgMetrics.clone()) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index dafe73d82c00..98a8f0330ca4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ /** - * :: AlphaComponent :: + * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ -@AlphaComponent +@Experimental class ParamGridBuilder { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index 146697680092..ddd34a54503a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -23,15 +23,17 @@ import java.util.UUID /** * Trait for an object with an immutable unique ID that identifies itself and its derivatives. */ -trait Identifiable { +private[spark] trait Identifiable { /** * An immutable unique ID for the object and its derivatives. */ val uid: String + + override def toString: String = uid } -object Identifiable { +private[spark] object Identifiable { /** * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index 56075c9a6b39..2a1db90f2ca2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -19,18 +19,14 @@ package org.apache.spark.ml.util import scala.collection.immutable.HashMap -import org.apache.spark.annotation.Experimental import org.apache.spark.ml.attribute._ import org.apache.spark.sql.types.StructField /** - * :: Experimental :: - * * Helper utilities for tree-based algorithms */ -@Experimental -object MetadataUtils { +private[spark] object MetadataUtils { /** * Examine a schema to identify the number of classes in a label column. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 11592b77eb35..76f651488aef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,15 +17,13 @@ package org.apache.spark.ml.util -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.types.{DataType, StructField, StructType} + /** - * :: DeveloperApi :: * Utils for handling schemas. */ -@DeveloperApi -object SchemaUtils { +private[spark] object SchemaUtils { // TODO: Move the utility methods to SQL. @@ -34,10 +32,15 @@ object SchemaUtils { * @param colName column name * @param dataType required column data type */ - def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + def checkColumnType( + schema: StructType, + colName: String, + dataType: DataType, + msg: String = ""): Unit = { val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.") + s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala new file mode 100644 index 000000000000..bc6041b22173 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel + +/** + * A Wrapper of PowerIterationClusteringModel to provide helper method for Python + */ +private[python] class PowerIterationClusteringModelWrapper(model: PowerIterationClusteringModel) + extends PowerIterationClusteringModel(model.k, model.assignments) { + + def getAssignments: RDD[Array[Any]] = { + model.assignments.map(x => Array(x.id, x.cluster)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2fa54df6fc2b..e628059c4af8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -28,6 +28,7 @@ import scala.reflect.ClassTag import net.razorvine.pickle._ +import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ @@ -43,13 +44,15 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel @@ -73,6 +76,15 @@ private[python] class PythonMLLibAPI extends Serializable { minPartitions: Int): JavaRDD[LabeledPoint] = MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) + /** + * Loads and serializes vectors saved with `RDD#saveAsTextFile`. + * @param jsc Java SparkContext + * @param path file or directory path in any Hadoop-supported file system URI + * @return serialized vectors in a RDD + */ + def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] = + MLUtils.loadVectors(jsc.sc, path) + private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], @@ -276,7 +288,7 @@ private[python] class PythonMLLibAPI extends Serializable { /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes( + def trainNaiveBayesModel( data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) @@ -344,7 +356,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for Python mllib GaussianMixture.run() * Returns a list containing weights, mean and covariance of each mixture component. */ - def trainGaussianMixture( + def trainGaussianMixtureModel( data: JavaRDD[Vector], k: Int, convergenceTol: Double, @@ -392,18 +404,45 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[Vector], wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Vector] = { + si: Array[Object]): RDD[Vector] = { val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) - } + } val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data).map(Vectors.dense) } + /** + * Java stub for Python mllib PowerIterationClustering.run(). This stub returns a + * handle to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see the + * Py4J documentation. + * @param data an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix. + * @param k number of clusters. + * @param maxIterations maximum number of iterations of the power iteration loop. + * @param initMode the initialization mode. This can be either "random" to use + * a random vector as vertex properties, or "degree" to use + * normalized sum similarities. Default: random. + */ + def trainPowerIterationClusteringModel( + data: JavaRDD[Vector], + k: Int, + maxIterations: Int, + initMode: String): PowerIterationClusteringModel = { + + val pic = new PowerIterationClustering() + .setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initMode) + + val model = pic.run(data.rdd.map(v => (v(0).toLong, v(1).toLong, v(2)))) + new PowerIterationClusteringModelWrapper(model) + } + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -428,7 +467,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -459,7 +498,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -494,7 +533,7 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. @@ -518,6 +557,16 @@ private[python] class PythonMLLibAPI extends Serializable { new ChiSqSelector(numTopFeatures).fit(data.rdd) } + /** + * Java stub for PCA.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitPCA(k: Int, data: JavaRDD[Vector]): PCAModel = { + new PCA(k).fit(data.rdd) + } + /** * Java stub for IDF.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. @@ -541,7 +590,7 @@ private[python] class PythonMLLibAPI extends Serializable { * @param seed initial seed for random generator * @return A handle to java Word2VecModelWrapper instance at python side */ - def trainWord2Vec( + def trainWord2VecModel( dataJRDD: JavaRDD[java.util.ArrayList[String]], vectorSize: Int, learningRate: Double, @@ -593,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable { def getVectors: JMap[String, JList[Float]] = { model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) } /** @@ -685,12 +736,14 @@ private[python] class PythonMLLibAPI extends Serializable { lossStr: String, numIterations: Int, learningRate: Double, - maxDepth: Int): GradientBoostedTreesModel = { + maxDepth: Int, + maxBins: Int): GradientBoostedTreesModel = { val boostingStrategy = BoostingStrategy.defaultParams(algoStr) boostingStrategy.setLoss(Losses.fromString(lossStr)) boostingStrategy.setNumIterations(numIterations) boostingStrategy.setLearningRate(learningRate) boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.setMaxBins(maxBins) boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) @@ -701,6 +754,14 @@ private[python] class PythonMLLibAPI extends Serializable { } } + def elementwiseProductVector(scalingVector: Vector, vector: Vector): Vector = { + new ElementwiseProduct(scalingVector).transform(vector) + } + + def elementwiseProductVector(scalingVector: Vector, vector: JavaRDD[Vector]): JavaRDD[Vector] = { + new ElementwiseProduct(scalingVector).transform(vector) + } + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. @@ -945,7 +1006,60 @@ private[python] class PythonMLLibAPI extends Serializable { r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any]))) } + /** + * Java stub for the estimate method of KernelDensity + */ + def estimateKernelDensity( + sample: JavaRDD[Double], + bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = { + new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate( + points.asScala.toArray) + } + /** + * Java stub for the update method of StreamingKMeansModel. + */ + def updateStreamingKMeansModel( + clusterCenters: JList[Vector], + clusterWeights: JList[Double], + data: JavaRDD[Vector], + decayFactor: Double, + timeUnit: String): JList[Object] = { + val model = new StreamingKMeansModel( + clusterCenters.asScala.toArray, clusterWeights.asScala.toArray) + .update(data, decayFactor, timeUnit) + List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava + } + + /** + * Wrapper around the generateLinearInput method of LinearDataGenerator. + */ + def generateLinearInputWrapper( + intercept: Double, + weights: JList[Double], + xMean: JList[Double], + xVariance: JList[Double], + nPoints: Int, + seed: Int, + eps: Double): Array[LabeledPoint] = { + LinearDataGenerator.generateLinearInput( + intercept, weights.asScala.toArray, xMean.asScala.toArray, + xVariance.asScala.toArray, nPoints, seed, eps).toArray + } + + /** + * Wrapper around the generateLinearRDD method of LinearDataGenerator. + */ + def generateLinearRDDWrapper( + sc: JavaSparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int, + intercept: Double): JavaRDD[LabeledPoint] = { + LinearDataGenerator.generateLinearRDD( + sc, nexamples, nfeatures, eps, nparts, intercept) + } } /** @@ -1242,7 +1356,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index bd2e9079ce1a..2df4d21e8cd5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -163,7 +163,7 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" override def toString: String = { - s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.get}" + s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index af24ab616663..f51ee36d0dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -21,19 +21,16 @@ import java.lang.{Iterable => JIterable} import scala.collection.JavaConverters._ -import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import breeze.numerics.{exp => brzExp, log => brzLog} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext, SparkException} -import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} - /** * Model for Naive Bayes Classifiers. * @@ -41,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features - * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" + * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ class NaiveBayesModel private[mllib] ( val labels: Array[Double], @@ -50,8 +47,13 @@ class NaiveBayesModel private[mllib] ( val modelType: String) extends ClassificationModel with Serializable with Saveable { + import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} + + private val piVector = new DenseVector(pi) + private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true) + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = - this(labels, pi, theta, "Multinomial") + this(labels, pi, theta, NaiveBayes.Multinomial) /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( @@ -60,20 +62,24 @@ class NaiveBayesModel private[mllib] ( theta: JIterable[JIterable[Double]]) = this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) - private val brzPi = new BDV[Double](pi) - private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t + require(supportedModelTypes.contains(modelType), + s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.") // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. - // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra // application of this condition (in predict function). - private val (brzNegTheta, brzNegThetaSum) = modelType match { - case "Multinomial" => (None, None) - case "Bernoulli" => - val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) - (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + private val (thetaMinusNegTheta, negThetaSum) = modelType match { + case Multinomial => (None, None) + case Bernoulli => + val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) + val thetaMinusNegTheta = thetaMatrix.map { value => + value - math.log(1.0 - math.exp(value)) + } + (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } override def predict(testData: RDD[Vector]): RDD[Double] = { @@ -85,20 +91,25 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { - val brzData = testData.toBreeze modelType match { - case "Multinomial" => - labels(brzArgmax(brzPi + brzTheta * brzData)) - case "Bernoulli" => - if (!brzData.forall(v => v == 0.0 || v == 1.0)) { - throw new SparkException( - s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") + case Multinomial => + val prob = thetaMatrix.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + labels(prob.argmax) + case Bernoulli => + testData.foreachActive { (index, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") + } } - labels(brzArgmax(brzPi + - (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) + val prob = thetaMinusNegTheta.get.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + labels(prob.argmax) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } } @@ -140,13 +151,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create Parquet data. val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.saveAsParquetFile(dataPath(path)) + dataRDD.write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. - val dataRDD = sqlContext.parquetFile(dataPath(path)) + val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) @@ -186,13 +197,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create Parquet data. val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.saveAsParquetFile(dataPath(path)) + dataRDD.write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. - val dataRDD = sqlContext.parquetFile(dataPath(path)) + val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta").take(1) @@ -223,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { s"($loadedClassName, $version). Supported:\n" + s" ($classNameV1_0, 1.0)") } - assert(model.pi.size == numClasses, + assert(model.pi.length == numClasses, s"NaiveBayesModel.load expected $numClasses classes," + - s" but class priors vector pi had ${model.pi.size} elements") - assert(model.theta.size == numClasses, + s" but class priors vector pi had ${model.pi.length} elements") + assert(model.theta.length == numClasses, s"NaiveBayesModel.load expected $numClasses classes," + - s" but class conditionals array theta had ${model.theta.size} elements") - assert(model.theta.forall(_.size == numFeatures), + s" but class conditionals array theta had ${model.theta.length} elements") + assert(model.theta.forall(_.length == numFeatures), s"NaiveBayesModel.load expected $numFeatures features," + s" but class conditionals array theta had elements of size:" + - s" ${model.theta.map(_.size).mkString(",")}") + s" ${model.theta.map(_.length).mkString(",")}") model } } @@ -250,9 +261,11 @@ class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { - def this(lambda: Double) = this(lambda, "Multinomial") + import NaiveBayes.{Bernoulli, Multinomial} - def this() = this(1.0, "Multinomial") + def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) + + def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -265,12 +278,11 @@ class NaiveBayes private ( /** * Set the model type using a string (case-sensitive). - * Supported options: "Multinomial" and "Bernoulli". - * (default: Multinomial) + * Supported options: "multinomial" (default) and "bernoulli". */ - def setModelType(modelType:String): NaiveBayes = { + def setModelType(modelType: String): NaiveBayes = { require(NaiveBayes.supportedModelTypes.contains(modelType), - s"NaiveBayes was created with an unknown ModelType: $modelType") + s"NaiveBayes was created with an unknown modelType: $modelType.") this.modelType = modelType this } @@ -301,7 +313,7 @@ class NaiveBayes private ( } if (!values.forall(v => v == 0.0 || v == 1.0)) { throw new SparkException( - s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") } } @@ -310,7 +322,7 @@ class NaiveBayes private ( // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( createCombiner = (v: Vector) => { - if (modelType == "Bernoulli") { + if (modelType == Bernoulli) { requireZeroOneBernoulliValues(v) } else { requireNonnegativeValues(v) @@ -345,11 +357,11 @@ class NaiveBayes private ( labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda) - case "Bernoulli" => math.log(n + 2.0 * lambda) + case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) + case Bernoulli => math.log(n + 2.0 * lambda) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } var j = 0 while (j < numFeatures) { @@ -368,8 +380,14 @@ class NaiveBayes private ( */ object NaiveBayes { + /** String name for multinomial model type. */ + private[classification] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. */ + private[classification] val Bernoulli: String = "bernoulli" + /* Set of modelTypes that NaiveBayes supports */ - private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. @@ -399,7 +417,7 @@ object NaiveBayes { * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda, "Multinomial").run(input) + new NaiveBayes(lambda, Multinomial).run(input) } /** @@ -422,7 +440,7 @@ object NaiveBayes { */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), - s"NaiveBayes was created with an unknown ModelType: $modelType") + s"NaiveBayes was created with an unknown modelType: $modelType.") new NaiveBayes(lambda, modelType).run(input) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 33104cf06c6e..348485560713 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -89,7 +89,7 @@ class SVMModel ( override protected def formatVersion: String = "1.0" override def toString: String = { - s"${super.toString}, numClasses = 2, threshold = ${threshold.get}" + s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 3b6790cce47c..fe09f6b75d28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -62,7 +62,7 @@ private[classification] object GLMClassificationModel { // Create Parquet data. val data = Data(weights, intercept, threshold) - sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(Seq(data), 1).toDF().write.parquet(Loader.dataPath(path)) } /** @@ -75,7 +75,7 @@ private[classification] object GLMClassificationModel { def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(datapath) + val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index c88410ac0ff4..fc509d2ba147 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils @@ -36,11 +37,11 @@ import org.apache.spark.util.Utils * independent Gaussian distributions with associated "mixing" weights * specifying each's contribution to the composite. * - * Given a set of sample points, this class will maximize the log-likelihood - * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed - * to find a global optimum. + * to find a global optimum. * * Note: For high-dimensional data (with many features), this algorithm may perform poorly. * This is due to high-dimensional data (a) making it difficult to cluster at all (based @@ -53,24 +54,24 @@ import org.apache.spark.util.Utils */ @Experimental class GaussianMixture private ( - private var k: Int, - private var convergenceTol: Double, + private var k: Int, + private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - + /** * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) - + // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 - - // an initializing GMM can be provided rather than using the + + // an initializing GMM can be provided rather than using the // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - + /** Set the initial GMM starting point, bypassing the random initialization. * You must call setK() prior to calling this method, and the condition * (model.k == this.k) must be met; failure will result in an IllegalArgumentException @@ -83,37 +84,37 @@ class GaussianMixture private ( } this } - + /** Return the user supplied initial GMM, if supplied */ def getInitialModel: Option[GaussianMixtureModel] = initialModel - + /** Set the number of Gaussians in the mixture model. Default: 2 */ def setK(k: Int): this.type = { this.k = k this } - + /** Return the number of Gaussians in the mixture model */ def getK: Int = k - + /** Set the maximum number of iterations to run. Default: 100 */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - + /** Return the maximum number of iterations to run */ def getMaxIterations: Int = maxIterations - + /** - * Set the largest change in log-likelihood at which convergence is + * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this } - + /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. @@ -132,41 +133,41 @@ class GaussianMixture private ( /** Perform expectation maximization */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext - + // we will operate on the data as breeze data val breezeData = data.map(_.toBreeze).cache() - + // Get length of the input vectors val d = breezeData.first().length - + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and // diagonal covariance matrices using component variances - // derived from the samples + // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - + case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) - (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) - new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) } } - - var llh = Double.MinValue // current log-likelihood + + var llh = Double.MinValue // current log-likelihood var llhp = 0.0 // previous log-likelihood - + var iter = 0 while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) { // create and broadcast curried cluster contribution function val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) - + // aggregate the cluster contribution for all sample points val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) - + // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum @@ -179,22 +180,25 @@ class GaussianMixture private ( gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) i = i + 1 } - + llhp = llh // current becomes previous llh = sums.logLikelihood // this is the freshly computed log-likelihood iter += 1 - } - + } + new GaussianMixtureModel(weights, gaussians) } - + + /** Java-friendly version of [[run()]] */ + def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) + /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) - v / x.length.toDouble + v / x.length.toDouble } - + /** * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) @@ -210,14 +214,14 @@ class GaussianMixture private ( // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { - new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + new ExpectationSum(0.0, Array.fill(k)(0.0), + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d))) } - + // compute cluster contributions for each input point // (U, T) => U for aggregation def add( - weights: Array[Double], + weights: Array[Double], dists: Array[MultivariateGaussian]) (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { @@ -235,7 +239,7 @@ private object ExpectationSum { i = i + 1 } sums - } + } } // Aggregation class for partial expectation results @@ -244,9 +248,9 @@ private class ExpectationSum( val weights: Array[Double], val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { - + val k = weights.length - + def +=(x: ExpectationSum): ExpectationSum = { var i = 0 while (i < k) { @@ -257,5 +261,5 @@ private class ExpectationSum( } logLikelihood += x.logLikelihood this - } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index c22862c130e7..cb807c803810 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} @@ -34,10 +35,10 @@ import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: * - * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points - * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are - * the respective mean and covariance for each Gaussian distribution i=1..k. - * + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is * the weight for Gaussian i, and weights.sum == 1 * @param gaussians Array of MultivariateGaussian where gaussians(i) represents @@ -45,9 +46,9 @@ import org.apache.spark.sql.{SQLContext, Row} */ @Experimental class GaussianMixtureModel( - val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ - + val weights: Array[Double], + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") override protected def formatVersion = "1.0" @@ -64,20 +65,24 @@ class GaussianMixtureModel( val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - + + /** Java-friendly version of [[predict()]] */ + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + /** * Given the input vectors, return the membership value of each vector - * to all mixture components. + * to all mixture components. */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) val bcWeights = sc.broadcast(weights) - points.map { x => + points.map { x => computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } - + /** * Compute the partial assignments for each vector */ @@ -89,7 +94,7 @@ class GaussianMixtureModel( val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) } - val pSum = p.sum + val pSum = p.sum for (i <- 0 until k) { p(i) /= pSum } @@ -126,13 +131,13 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataArray = Array.tabulate(weights.length) { i => Data(weights(i), gaussians(i).mu, gaussians(i).sigma) } - sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataFrame = sqlContext.parquetFile(dataPath) + val dataFrame = sqlContext.read.parquet(dataPath) val dataArray = dataFrame.select("weight", "mu", "sigma").collect() // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index ba228b11fcec..8ecb3df11d95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -110,7 +110,7 @@ object KMeansModel extends Loader[KMeansModel] { val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => Cluster(id, point) }.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): KMeansModel = { @@ -120,7 +120,7 @@ object KMeansModel extends Loader[KMeansModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val centriods = sqlContext.parquetFile(Loader.dataPath(path)) + val centriods = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Cluster](centriods.schema) val localCentriods = centriods.map(Cluster.apply).collect() assert(k == localCentriods.size) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6cf26445f20a..974b26924dfb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.rdd.RDD @@ -345,6 +346,11 @@ class DistributedLDAModel private ( } } + /** Java-friendly version of [[topicDistributions]] */ + def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { + JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 6fa2fe053c6a..8e5154b902d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -273,7 +273,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Default: 1024, following the original Online LDA paper. */ def setTau0(tau0: Double): this.type = { - require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") + require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") this.tau0 = tau0 this } @@ -339,7 +339,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { override private[clustering] def initialize( docs: RDD[(Long, Vector)], - lda: LDA): OnlineLDAOptimizer = { + lda: LDA): OnlineLDAOptimizer = { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * uses digamma which is accurate but expensive. */ private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { - val rowSum = sum(alpha(breeze.linalg.*, ::)) + val rowSum = sum(alpha(breeze.linalg.*, ::)) val digAlpha = digamma(alpha) val digRowSum = digamma(rowSum) val result = digAlpha(::, breeze.linalg.*) - digRowSum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index aa53e88d5985..e7a243f854e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -74,7 +74,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = model.assignments.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { @@ -86,7 +86,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + val assignments = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) val assignmentsRDD = assignments.map { @@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + * initMode: "random"}. */ def this() = this(k = 2, maxIterations = 100, initMode = "random") @@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging { /** * Generates random vertex properties (v0) to start power iteration. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm @@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging { * Generates the degree vector as the vertex properties (v0) to start power iteration. * It is not exactly the node degrees but just the normalized sum similarities. Call it * as degree vector because it is used in the PIC paper. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ @@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging { val v0 = g.vertices.mapValues(_ / sum) GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) } - + /** * Runs power iteration. * @param g input graph with edges representing the normalized affinity matrix (W) and vertices diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 812014a04171..d9b34cec6489 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -21,8 +21,10 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -178,7 +180,7 @@ class StreamingKMeans( /** Set the decay factor directly (for forgetful algorithms). */ def setDecayFactor(a: Double): this.type = { - this.decayFactor = decayFactor + this.decayFactor = a this } @@ -234,6 +236,9 @@ class StreamingKMeans( } } + /** Java-friendly version of `trainOn`. */ + def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) + /** * Use the clustering model to make predictions on batches of data from a DStream. * @@ -245,6 +250,11 @@ class StreamingKMeans( data.map(model.predict) } + /** Java-friendly version of `predictOn`. */ + def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { + JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) + } + /** * Use the model to make predictions on the values of a DStream and carry over its keys. * @@ -257,6 +267,14 @@ class StreamingKMeans( data.mapValues(model.predict) } + /** Java-friendly version of `predictOnValues`. */ + def predictOnValues[K]( + data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { + implicit val tag = fakeClassTag[K] + JavaPairDStream.fromPairDStream( + predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]]) + } + /** Check whether cluster centers have been initialized. */ private[this] def assertInitialized(): Unit = { if (model.clusterCenters == null) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index a8378a76d20a..bf6eb1d5bd2a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ +import org.apache.spark.sql.DataFrame /** * Evaluator for multilabel classification. @@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._ */ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { + /** + * An auxiliary constructor taking a DataFrame. + * @param predictionAndLabels a DataFrame with two double array columns: prediction and label + */ + private[mllib] def this(predictionAndLabels: DataFrame) = + this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray))) + private lazy val numDocs: Long = predictionAndLabels.count() private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index b9b54b93c27f..5b5a2a1450f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -31,6 +31,8 @@ import org.apache.spark.rdd.RDD * ::Experimental:: * Evaluator for ranking algorithms. * + * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. + * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. */ @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 9cc2d0ffcab7..5f8c1dea237b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -108,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * (ordered by statistic value descending) */ @Experimental -class ChiSqSelector (val numTopFeatures: Int) { +class ChiSqSelector (val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index b0985baf9b27..d67fe6c3ee4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._ * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. - * @param scalingVector The values used to scale the reference vector's individual components. + * @param scalingVec The values used to scale the reference vector's individual components. */ @Experimental -class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { +class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. @@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { * @return transformed vector. */ override def transform(vector: Vector): Vector = { - require(vector.size == scalingVector.size, - s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}") + require(vector.size == scalingVec.size, + s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") vector match { case dv: DenseVector => val values: Array[Double] = dv.values.clone() - val dim = scalingVector.size + val dim = scalingVec.size var i = 0 while (i < dim) { - values(i) *= scalingVector(i) + values(i) *= scalingVec(i) i += 1 } Vectors.dense(values) @@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { val dim = values.length var i = 0 while (i < dim) { - values(i) *= scalingVector(indices(i)) + values(i) *= scalingVec(indices(i)) i += 1 } Vectors.sparse(size, indices, values) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index a89eea0e21be..3fab7ea79bef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -144,7 +144,7 @@ private object IDF { * Since arrays are initialized to 0 by default, * we just omit changing those entries. */ - if(df(j) >= minDocFreq) { + if (df(j) >= minDocFreq) { inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) } j += 1 @@ -159,7 +159,7 @@ private object IDF { * Represents an IDF model that can transform term frequency vectors. */ @Experimental -class IDFModel private[mllib] (val idf: Vector) extends Serializable { +class IDFModel private[spark] (val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 4e01e402b428..2a66263d8b7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -68,7 +68,7 @@ class PCA(val k: Int) { * @param k number of principal components. * @param pc a principal components Matrix. Each column is one principal component. */ -class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { +class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { /** * Transform a vector by computed Principal Components. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 6ae6917eae59..c73b8f258060 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -90,7 +90,7 @@ class StandardScalerModel ( @DeveloperApi def setWithMean(withMean: Boolean): this.type = { - require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") this.withMean = withMean this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 98e83112f52a..f087d06d2a46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -42,32 +42,32 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.sql.{SQLContext, Row} /** - * Entry in vocabulary + * Entry in vocabulary */ private case class VocabWord( var word: String, var cn: Int, var point: Array[Int], var code: Array[Int], - var codeLen:Int + var codeLen: Int ) /** * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus - * and then learns vector representation of words in the vocabulary. - * The vector representation can be used as features in + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in * natural language processing and machine learning algorithms. - * - * We used skip-gram model in our implementation and hierarchical softmax + * + * We used skip-gram model in our implementation and hierarchical softmax * method to train the model. The variable names in the implementation * matches the original C implementation. * - * For original C implementation, see https://code.google.com/p/word2vec/ - * For research papers, see + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see * Efficient Estimation of Word Representations in Vector Space - * and + * and * Distributed Representations of Words and Phrases and their Compositionality. */ @Experimental @@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 - + /** * Sets vector size (default: 100). */ @@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging { this } - /** - * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). */ def setMinCount(minCount: Int): this.type = { this.minCount = minCount this } - + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -150,14 +150,17 @@ class Word2Vec extends Serializable with Logging { .map(x => VocabWord( x._1, x._2, - new Array[Int](MAX_CODE_LENGTH), - new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), 0)) .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) - + vocabSize = vocab.length + require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " + + "the setting of minCount, which could be large enough to remove all your words in sentences.") + var a = 0 while (a < vocabSize) { vocabHash += vocab(a).word -> a @@ -195,8 +198,8 @@ class Word2Vec extends Serializable with Logging { } var pos1 = vocabSize - 1 var pos2 = vocabSize - - var min1i = 0 + + var min1i = 0 var min2i = 0 a = 0 @@ -265,15 +268,15 @@ class Word2Vec extends Serializable with Logging { val words = dataset.flatMap(x => x) learnVocab(words) - + createBinaryTree() - + val sc = dataset.context val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => new Iterator[Array[Int]] { def hasNext: Boolean = iter.hasNext @@ -294,7 +297,7 @@ class Word2Vec extends Serializable with Logging { } } } - + val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) @@ -399,7 +402,7 @@ class Word2Vec extends Serializable with Logging { } } newSentences.unpersist() - + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] var i = 0 while (i < vocabSize) { @@ -428,7 +431,7 @@ class Word2Vec extends Serializable with Logging { * Word2Vec model */ @Experimental -class Word2VecModel private[mllib] ( +class Word2VecModel private[spark] ( model: Map[String, Array[Float]]) extends Serializable with Saveable { // wordList: Ordered list of words obtained from model. @@ -466,7 +469,7 @@ class Word2VecModel private[mllib] ( val norm1 = blas.snrm2(n, v1, 1) val norm2 = blas.snrm2(n, v2, 1) if (norm1 == 0 || norm2 == 0) return 0.0 - blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 } override protected def formatVersion = "1.0" @@ -477,7 +480,7 @@ class Word2VecModel private[mllib] ( /** * Transforms a word to its vector representation - * @param word a word + * @param word a word * @return vector representation of word */ def transform(word: String): Vector = { @@ -492,18 +495,18 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of a word * @param word a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector,num) + findSynonyms(vector, num) } /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { @@ -556,7 +559,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataFrame = sqlContext.parquetFile(dataPath) + val dataFrame = sqlContext.read.parquet(dataPath) val dataArray = dataFrame.select("word", "vector").collect() @@ -580,7 +583,7 @@ object Word2VecModel extends Loader[Word2VecModel] { sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 87052e1ba853..3523f1804325 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -213,9 +213,9 @@ private[spark] object BLAS extends Serializable with Logging { def scal(a: Double, x: Vector): Unit = { x match { case sx: SparseVector => - f2jBLAS.dscal(sx.values.size, a, sx.values, 1) + f2jBLAS.dscal(sx.values.length, a, sx.values, 1) case dx: DenseVector => - f2jBLAS.dscal(dx.values.size, a, dx.values, 1) + f2jBLAS.dscal(dx.values.length, a, dx.values, 1) case _ => throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") } @@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } - + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. @@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 } i += 1 - } + } } private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { @@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging { def gemv( alpha: Double, A: Matrix, - x: DenseVector, + x: Vector, beta: Double, y: DenseVector): Unit = { require(A.numCols == x.size, @@ -473,27 +473,32 @@ private[spark] object BLAS extends Serializable with Logging { if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { - A match { - case sparse: SparseMatrix => - gemv(alpha, sparse, x, beta, y) - case dense: DenseMatrix => - gemv(alpha, dense, x, beta, y) + (A, x) match { + case (smA: SparseMatrix, dvx: DenseVector) => + gemv(alpha, smA, dvx, beta, y) + case (smA: SparseMatrix, svx: SparseVector) => + gemv(alpha, smA, svx, beta, y) + case (dmA: DenseMatrix, dvx: DenseVector) => + gemv(alpha, dmA, dvx, beta, y) + case (dmA: DenseMatrix, svx: SparseVector) => + gemv(alpha, dmA, svx, beta, y) case _ => - throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") + throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " + + s"${A.getClass} and vector type ${x.getClass}.") } } } /** * y := alpha * A * x + beta * y - * For `DenseMatrix` A. + * For `DenseMatrix` A and `DenseVector` x. */ private def gemv( alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, - y: DenseVector): Unit = { + y: DenseVector): Unit = { val tStrA = if (A.isTransposed) "T" else "N" val mA = if (!A.isTransposed) A.numRows else A.numCols val nA = if (!A.isTransposed) A.numCols else A.numRows @@ -503,14 +508,134 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * For `SparseMatrix` A. + * For `DenseMatrix` A and `SparseVector` x. + */ + private def gemv( + alpha: Double, + A: DenseMatrix, + x: SparseVector, + beta: Double, + y: DenseVector): Unit = { + val mA: Int = A.numRows + val nA: Int = A.numCols + + val Avals = A.values + + val xIndices = x.indices + val xNnz = xIndices.length + val xValues = x.values + val yValues = y.values + + if (alpha == 0.0) { + scal(beta, y) + return + } + + if (A.isTransposed) { + var rowCounterForA = 0 + while (rowCounterForA < mA) { + var sum = 0.0 + var k = 0 + while (k < xNnz) { + sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA) + k += 1 + } + yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA) + rowCounterForA += 1 + } + } else { + var rowCounterForA = 0 + while (rowCounterForA < mA) { + var sum = 0.0 + var k = 0 + while (k < xNnz) { + sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA) + k += 1 + } + yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA) + rowCounterForA += 1 + } + } + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A and `SparseVector` x. + */ + private def gemv( + alpha: Double, + A: SparseMatrix, + x: SparseVector, + beta: Double, + y: DenseVector): Unit = { + val xValues = x.values + val xIndices = x.indices + val xNnz = xIndices.length + + val yValues = y.values + + val mA: Int = A.numRows + val nA: Int = A.numCols + + val Avals = A.values + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices + + if (alpha == 0.0) { + scal(beta, y) + return + } + + if (A.isTransposed) { + var rowCounter = 0 + while (rowCounter < mA) { + var i = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var sum = 0.0 + var k = 0 + while (k < xNnz && i < indEnd) { + if (xIndices(k) == Acols(i)) { + sum += Avals(i) * xValues(k) + i += 1 + } + k += 1 + } + yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) + rowCounter += 1 + } + } else { + scal(beta, y) + + var colCounterForA = 0 + var k = 0 + while (colCounterForA < nA && k < xNnz) { + if (xIndices(k) == colCounterForA) { + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + + val xTemp = xValues(k) * alpha + while (i < indEnd) { + val rowIndex = Arows(i) + yValues(Arows(i)) += Avals(i) * xTemp + i += 1 + } + k += 1 + } + colCounterForA += 1 + } + } + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A and `DenseVector` x. */ private def gemv( alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, - y: DenseVector): Unit = { + y: DenseVector): Unit = { val xValues = x.values val yValues = y.values val mA: Int = A.numRows @@ -534,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - // Scale vector first if `beta` is not equal to 0.0 - if (beta != 0.0) { - scal(beta, y) - } + scal(beta, y) // Perform matrix-vector multiplication and add to y var colCounterForA = 0 while (colCounterForA < nA) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 866936aa4f11..ae3ba3099c87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition { require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, s"k = $k and/or n = $n are too large to compute an eigendecomposition") - + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index a609674df6b8..75e7004464af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -77,8 +77,13 @@ sealed trait Matrix extends Serializable { C } - /** Convenience method for `Matrix`-`DenseVector` multiplication. */ + /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ def multiply(y: DenseVector): DenseVector = { + multiply(y.asInstanceOf[Vector]) + } + + /** Convenience method for `Matrix`-`Vector` multiplication. */ + def multiply(y: Vector): DenseVector = { val output = new DenseVector(new Array[Double](numRows)) BLAS.gemv(1.0, this, y, 0.0, output) output @@ -109,6 +114,16 @@ sealed trait Matrix extends Serializable { * corresponding value in the matrix with type `Double`. */ private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + + /** + * Find the number of non-zero active values. + */ + def numNonzeros: Int + + /** + * Find the number of values stored explicitly. These values can be zero as well. + */ + def numActives: Int } @DeveloperApi @@ -188,10 +203,13 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { } } - override def hashCode(): Int = 1994 + // see [SPARK-8647], this achieves the needed constant hash code without constant no. + override def hashCode(): Int = classOf[MatrixUDT].getName.hashCode() override def typeName: String = "matrix" + override def pyUDT: String = "pyspark.mllib.linalg.MatrixUDT" + private[spark] override def asNullable: MatrixUDT = this } @@ -316,6 +334,10 @@ class DenseMatrix( } } + override def numNonzeros: Int = values.count(_ != 0) + + override def numActives: Int = values.length + /** * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed * set to false. @@ -585,6 +607,11 @@ class SparseMatrix( def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } + + override def numNonzeros: Int = values.count(_ != 0) + + override def numActives: Int = values.length + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index f6bcdf83cd33..c9c27425d287 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } override def serialize(obj: Any): Row = { - val row = new GenericMutableRow(4) obj match { case SparseVector(size, indices, values) => + val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, indices.toSeq) row.update(3, values.toSeq) + row case DenseVector(values) => + val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) row.update(3, values.toSeq) + row + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case row: Row => + row } - row } override def deserialize(datum: Any): Vector = { datum match { - // TODO: something wrong with UDT serialization - case v: Vector => - v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") @@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case v: Vector => + v } } @@ -225,7 +234,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } } - override def hashCode: Int = 7919 + // see [SPARK-8647], this achieves the needed constant hash code without constant no. + override def hashCode(): Int = classOf[VectorUDT].getName.hashCode() override def typeName: String = "vector" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 3be530fa0753..1c33b43ea7a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -146,7 +146,7 @@ class IndexedRowMatrix( val indexedRows = indices.zip(svd.U.rows).map { case (i, v) => IndexedRow(i, v) } - new IndexedRowMatrix(indexedRows, nRows, nCols) + new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt) } else { null } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 9a89a6f3a515..1626da9c3d2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -219,7 +219,7 @@ class RowMatrix( val computeMode = mode match { case "auto" => - if(k > 5000) { + if (k > 5000) { logWarning(s"computing svd with k=$k and n=$n, please check necessity") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 4b7d0589c973..ab7611fd077e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,13 +19,14 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV} +import breeze.linalg.{DenseVector => BDV, norm} import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} + /** * Class used to solve an optimization problem using Gradient Descent. * @param gradient Gradient function to be used. @@ -38,6 +39,7 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va private var numIterations: Int = 100 private var regParam: Double = 0.0 private var miniBatchFraction: Double = 1.0 + private var convergenceTol: Double = 0.001 /** * Set the initial step size of SGD for the first step. Default 1.0. @@ -75,6 +77,23 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va this } + /** + * Set the convergence tolerance. Default 0.001 + * convergenceTol is a condition which decides iteration termination. + * The end of iteration is decided based on below logic. + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * Must be between 0.0 and 1.0 inclusively. + */ + def setConvergenceTol(tolerance: Double): this.type = { + require(0.0 <= tolerance && tolerance <= 1.0) + this.convergenceTol = tolerance + this + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for SGD. @@ -112,7 +131,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private va numIterations, regParam, miniBatchFraction, - initialWeights) + initialWeights, + convergenceTol) weights } @@ -131,17 +151,20 @@ object GradientDescent extends Logging { * Sampling, and averaging the subgradients over this subset is performed using one standard * spark map-reduce in each iteration. * - * @param data - Input data for SGD. RDD of the set of data examples, each of - * the form (label, [feature values]). - * @param gradient - Gradient object (used to compute the gradient of the loss function of - * one single data example) - * @param updater - Updater function to actually perform a gradient step in a given direction. - * @param stepSize - initial step size for the first step - * @param numIterations - number of iterations that SGD should be run. - * @param regParam - regularization parameter - * @param miniBatchFraction - fraction of the input data set that should be used for - * one iteration of SGD. Default value 1.0. - * + * @param data Input data for SGD. RDD of the set of data examples, each of + * the form (label, [feature values]). + * @param gradient Gradient object (used to compute the gradient of the loss function of + * one single data example) + * @param updater Updater function to actually perform a gradient step in a given direction. + * @param stepSize initial step size for the first step + * @param numIterations number of iterations that SGD should be run. + * @param regParam regularization parameter + * @param miniBatchFraction fraction of the input data set that should be used for + * one iteration of SGD. Default value 1.0. + * @param convergenceTol Minibatch iteration will end before numIterations if the relative + * difference between the current weight and the previous weight is less + * than this value. In measuring convergence, L2 norm is calculated. + * Default value 0.001. Must be between 0.0 and 1.0 inclusively. * @return A tuple containing two elements. The first element is a column matrix containing * weights for every feature, and the second element is an array containing the * stochastic loss computed for every iteration. @@ -154,9 +177,20 @@ object GradientDescent extends Logging { numIterations: Int, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): (Vector, Array[Double]) = { + initialWeights: Vector, + convergenceTol: Double): (Vector, Array[Double]) = { + + // convergenceTol should be set with non minibatch settings + if (miniBatchFraction < 1.0 && convergenceTol > 0.0) { + logWarning("Testing against a convergenceTol when using miniBatchFraction " + + "< 1.0 can be unstable because of the stochasticity in sampling.") + } val stochasticLossHistory = new ArrayBuffer[Double](numIterations) + // Record previous weight and current one to calculate solution vector difference + + var previousWeights: Option[Vector] = None + var currentWeights: Option[Vector] = None val numExamples = data.count() @@ -179,9 +213,11 @@ object GradientDescent extends Logging { * if it's L2 updater; for L1 updater, the same logic is followed. */ var regVal = updater.compute( - weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 - for (i <- 1 to numIterations) { + var converged = false // indicates whether converged based on convergenceTol + var i = 1 + while (!converged && i <= numIterations) { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) @@ -204,12 +240,21 @@ object GradientDescent extends Logging { */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam) + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), + stepSize, i, regParam) weights = update._1 regVal = update._2 + + previousWeights = currentWeights + currentWeights = Some(weights) + if (previousWeights != None && currentWeights != None) { + converged = isConverged(previousWeights.get, + currentWeights.get, convergenceTol) + } } else { logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") } + i += 1 } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( @@ -218,4 +263,32 @@ object GradientDescent extends Logging { (weights, stochasticLossHistory.toArray) } + + def runMiniBatchSGD( + data: RDD[(Double, Vector)], + gradient: Gradient, + updater: Updater, + stepSize: Double, + numIterations: Int, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Vector): (Vector, Array[Double]) = + GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations, + regParam, miniBatchFraction, initialWeights, 0.001) + + + private def isConverged( + previousWeights: Vector, + currentWeights: Vector, + convergenceTol: Double): Boolean = { + // To compare with convergence tolerance. + val previousBDV = previousWeights.toBreeze.toDenseVector + val currentBDV = currentWeights.toBreeze.toDenseVector + + // This represents the difference of updated weights in the iteration. + val solutionVecDiff: Double = norm(previousBDV - currentBDV) + + solutionVecDiff < convergenceTol * Math.max(norm(currentBDV), 1.0) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 354e90f3eeaa..5e882d4ebb10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -23,13 +23,16 @@ import javax.xml.transform.stream.StreamResult import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** + * :: DeveloperApi :: * Export model to the PMML format * Predictive Model Markup Language (PMML) is an XML-based file format * developed by the Data Mining Group (www.dmg.org). */ +@DeveloperApi trait PMMLExportable { /** @@ -41,30 +44,38 @@ trait PMMLExportable { } /** + * :: Experimental :: * Export the model to a local file in PMML format */ + @Experimental def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } /** + * :: Experimental :: * Export the model to a directory on a distributed file system in PMML format */ + @Experimental def toPMML(sc: SparkContext, path: String): Unit = { val pmml = toPMML() sc.parallelize(Array(pmml), 1).saveAsTextFile(path) } /** + * :: Experimental :: * Export the model to the OutputStream in PMML format */ + @Experimental def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } /** + * :: Experimental :: * Export the model to a String in PMML format */ + @Experimental def toPMML(): String = { val writer = new StringWriter toPMML(new StreamResult(writer)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 34b447584e52..622b53a252ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, + model : GeneralizedLinearModel, description : String, normalizationMethod : RegressionNormalizationMethodType, - threshold: Double) + threshold: Double) extends PMMLModelExport { populateBinaryClassificationPMML() @@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( .withUsageType(FieldUsageType.ACTIVE)) regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } - + // add target field val targetField = FieldName.create("target") dataDictionary @@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport( miningSchema .withMiningFields(new MiningField(targetField) .withUsageType(FieldUsageType.TARGET)) - + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) - + pmml.setDataDictionary(dataDictionary) pmml.withModels(regressionModel) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index ebdeae50bb32..c5fdecd3ca17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -25,7 +25,7 @@ import scala.beans.BeanProperty import org.dmg.pmml.{Application, Header, PMML, Timestamp} private[mllib] trait PMMLModelExport { - + /** * Holder of the exported model in PMML format */ @@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport { val pmml: PMML = new PMML setHeader(pmml) - + private def setHeader(pmml: PMML): Unit = { val version = getClass.getPackage.getImplementationVersion val app = new Application().withName("Apache Spark MLlib").withVersion(version) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index c16e83d6a067..29bd689e1185 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel private[mllib] object PMMLModelExportFactory { - + /** - * Factory object to help creating the necessary PMMLModelExport implementation + * Factory object to help creating the necessary PMMLModelExport implementation * taking as input the machine learning model (for example KMeansModel). */ def createPMMLModelExport(model: Any): PMMLModelExport = { @@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory { new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") case svm: SVMModel => new BinaryClassificationPMMLModelExport( - svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm, "linear SVM", RegressionNormalizationMethodType.NONE, svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => if (logistic.numClasses == 2) { @@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory { "PMML Export not supported for model: " + model.getClass.getName) } } - + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 8341bb86afd7..174d5e0f6c9f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -52,7 +52,7 @@ object RandomRDDs { numPartitions: Int = 0, seed: Long = Utils.random.nextLong()): RDD[Double] = { val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** @@ -234,7 +234,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution - * @param scale scale parameter (> 0) for the gamma distribution + * @param scale scale parameter (> 0) for the gamma distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -293,7 +293,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param mean mean for the log normal distribution - * @param std standard deviation for the log normal distribution + * @param std standard deviation for the log normal distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -671,7 +671,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution. - * @param scale scale parameter (> 0) for the gamma distribution. + * @param scale scale parameter (> 0) for the gamma distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index dddefe1944e9..93290e650852 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -175,7 +175,7 @@ class ALS private ( /** * :: DeveloperApi :: * Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default - * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. + * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. * `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement, * at the cost of speed. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 88c214840331..43d219a49cf4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger} import scala.collection.mutable +import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.Path import org.json4s._ @@ -79,6 +80,30 @@ class MatrixFactorizationModel( blas.ddot(rank, userVector, 1, productVector, 1) } + /** + * Return approximate numbers of users and products in the given usersProducts tuples. + * This method is based on `countApproxDistinct` in class `RDD`. + * + * @param usersProducts RDD of (user, product) pairs. + * @return approximate numbers of users and products. + */ + private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = { + val zeroCounterUser = new HyperLogLogPlus(4, 0) + val zeroCounterProduct = new HyperLogLogPlus(4, 0) + val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))( + (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => { + hllTuple._1.offer(v._1) + hllTuple._2.offer(v._2) + hllTuple + }, + (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => { + h1._1.addAll(h2._1) + h1._2.addAll(h2._2) + h1 + }) + (aggregated._1.cardinality(), aggregated._2.cardinality()) + } + /** * Predict the rating of many users for many products. * The output RDD has an element per each element in the input RDD (including all duplicates) @@ -88,12 +113,30 @@ class MatrixFactorizationModel( * @return RDD of Ratings. */ def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { - val users = userFeatures.join(usersProducts).map { - case (user, (uFeatures, product)) => (product, (user, uFeatures)) - } - users.join(productFeatures).map { - case (product, ((user, uFeatures), pFeatures)) => - Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + // Previously the partitions of ratings are only based on the given products. + // So if the usersProducts given for prediction contains only few products or + // even one product, the generated ratings will be pushed into few or single partition + // and can't use high parallelism. + // Here we calculate approximate numbers of users and products. Then we decide the + // partitions should be based on users or products. + val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts) + + if (usersCount < productsCount) { + val users = userFeatures.join(usersProducts).map { + case (user, (uFeatures, product)) => (product, (user, uFeatures)) + } + users.join(productFeatures).map { + case (product, ((user, uFeatures), pFeatures)) => + Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + } + } else { + val products = productFeatures.join(usersProducts.map(_.swap)).map { + case (product, (pFeatures, user)) => (user, (product, pFeatures)) + } + products.join(userFeatures).map { + case (user, ((product, pFeatures), uFeatures)) => + Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + } } } @@ -281,8 +324,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) - model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path)) - model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path)) + model.userFeatures.toDF("id", "features").write.parquet(userPath(path)) + model.productFeatures.toDF("id", "features").write.parquet(productPath(path)) } def load(sc: SparkContext, path: String): MatrixFactorizationModel = { @@ -292,11 +335,11 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val rank = (metadata \ "rank").extract[Int] - val userFeatures = sqlContext.parquetFile(userPath(path)) + val userFeatures = sqlContext.read.parquet(userPath(path)) .map { case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) } - val productFeatures = sqlContext.parquetFile(productPath(path)) + val productFeatures = sqlContext.read.parquet(productPath(path)) .map { case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 26be30ff9d6f..6709bd79bc82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ val initialWeights = { if (numOfLinearPredictor == 1) { - Vectors.dense(new Array[Double](numFeatures)) + Vectors.zeros(numFeatures) } else if (addIntercept) { - Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor)) + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) } else { - Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor)) + Vectors.zeros(numFeatures * numOfLinearPredictor) } } run(input, initialWeights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 4ce541ae5bed..f3b46c75c05f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -170,26 +170,26 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { case class Data(boundary: Double, prediction: Double) def save( - sc: SparkContext, - path: String, - boundaries: Array[Double], - predictions: Array[Double], + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], isotonic: Boolean): Unit = { val sqlContext = new SQLContext(sc) val metadata = compact(render( - ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("isotonic" -> isotonic))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) sqlContext.createDataFrame( boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } - ).saveAsParquetFile(dataPath(path)) + ).write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(dataPath(path)) + val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("boundary", "prediction").collect() @@ -203,7 +203,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = loadMetadata(sc, path) - val isotonic = (metadata \ "isotonic").extract[Boolean] + val isotonic = (metadata \ "isotonic").extract[Boolean] val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index e0c03d8180c7..7d28ffad45c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -73,7 +73,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. - * This solves the l1-regularized least squares regression formulation + * This solves the l2-regularized least squares regression formulation * f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index cea8f3f47307..141052ba813e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -83,21 +83,15 @@ abstract class StreamingLinearAlgorithm[ throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - val initialWeights = - model match { - case Some(m) => - m.weights - case None => - val numFeatures = rdd.first().features.size - Vectors.dense(numFeatures) + if (!rdd.isEmpty) { + model = Some(algorithm.run(rdd, model.get.weights)) + logInfo(s"Model updated at time ${time.toString}") + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") } - model = Some(algorithm.run(rdd, initialWeights)) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.get.weights.size match { - case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.get.weights.toArray.mkString("[", ",", "]") + logInfo(s"Current model: weights, ${display}") } - logInfo("Current model: weights, %s".format (display)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index a49153bf73c0..c6d04464a12b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -79,10 +79,16 @@ class StreamingLinearRegressionWithSGD private[mllib] ( this } - /** Set the initial weights. Default: [0.0, 0.0]. */ + /** Set the initial weights. */ def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this } + /** Set the convergence tolerance. */ + def setConvergenceTol(tolerance: Double): this.type = { + this.algorithm.optimizer.setConvergenceTol(tolerance) + this + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index b55944f74f62..317d3a570263 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -60,7 +60,7 @@ private[regression] object GLMRegressionModel { val data = Data(weights, intercept) val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } /** @@ -72,7 +72,7 @@ private[regression] object GLMRegressionModel { def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(datapath) + val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 79747cc5d7d7..58a50f9c19f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -17,52 +17,101 @@ package org.apache.spark.mllib.stat +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD -private[stat] object KernelDensity { +/** + * :: Experimental :: + * Kernel density estimation. Given a sample from a population, estimate its probability density + * function at each of the given evaluation points using kernels. Only Gaussian kernel is supported. + * + * Scala example: + * + * {{{ + * val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0)) + * val kd = new KernelDensity() + * .setSample(sample) + * .setBandwidth(3.0) + * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) + * }}} + */ +@Experimental +class KernelDensity extends Serializable { + + import KernelDensity._ + + /** Bandwidth of the kernel function. */ + private var bandwidth: Double = 1.0 + + /** A sample from a population. */ + private var sample: RDD[Double] = _ + /** - * Given a set of samples from a distribution, estimates its density at the set of given points. - * Uses a Gaussian kernel with the given standard deviation. + * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). */ - def estimate(samples: RDD[Double], standardDeviation: Double, - evaluationPoints: Array[Double]): Array[Double] = { - if (standardDeviation <= 0.0) { - throw new IllegalArgumentException("Standard deviation must be positive") - } + def setBandwidth(bandwidth: Double): this.type = { + require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") + this.bandwidth = bandwidth + this + } - // This gets used in each Gaussian PDF computation, so compute it up front - val logStandardDeviationPlusHalfLog2Pi = - math.log(standardDeviation) + 0.5 * math.log(2 * math.Pi) + /** + * Sets the sample to use for density estimation. + */ + def setSample(sample: RDD[Double]): this.type = { + this.sample = sample + this + } + + /** + * Sets the sample to use for density estimation (for Java users). + */ + def setSample(sample: JavaRDD[java.lang.Double]): this.type = { + this.sample = sample.rdd.asInstanceOf[RDD[Double]] + this + } + + /** + * Estimates probability density function at the given array of points. + */ + def estimate(points: Array[Double]): Array[Double] = { + val sample = this.sample + val bandwidth = this.bandwidth + + require(sample != null, "Must set sample before calling estimate.") - val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))( + val n = points.length + // This gets used in each Gaussian PDF computation, so compute it up front + val logStandardDeviationPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi) + val (densities, count) = sample.aggregate((new Array[Double](n), 0L))( (x, y) => { var i = 0 - while (i < evaluationPoints.length) { - x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi, - evaluationPoints(i)) + while (i < n) { + x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i)) i += 1 } - (x._1, i) + (x._1, x._2 + 1) }, (x, y) => { - var i = 0 - while (i < evaluationPoints.length) { - x._1(i) += y._1(i) - i += 1 - } + blas.daxpy(n, 1.0, y._1, 1, x._1, 1) (x._1, x._2 + y._2) }) - - var i = 0 - while (i < points.length) { - points(i) /= count - i += 1 - } - points + blas.dscal(n, 1.0 / count, densities, 1) + densities } +} + +private object KernelDensity { - private def normPdf(mean: Double, standardDeviation: Double, - logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = { + /** Evaluates the PDF of a normal distribution. */ + def normPdf( + mean: Double, + standardDeviation: Double, + logStandardDeviationPlusHalfLog2Pi: Double, + x: Double): Double = { val x0 = x - mean val x1 = x0 / standardDeviation val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 0b1755613aac..d321cc554c1c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -70,7 +70,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - val localCurrMean= currMean + val localCurrMean = currMean val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 32561620ac91..900007ec6bc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -80,6 +81,10 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) + /** Java-friendly version of [[corr()]] */ + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) + /** * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. @@ -96,6 +101,10 @@ object Statistics { */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) + /** Java-friendly version of [[corr()]] */ + def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = + corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) + /** * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. @@ -149,18 +158,4 @@ object Statistics { def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } - - /** - * Given an empirical distribution defined by the input RDD of samples, estimate its density at - * each of the given evaluation points using a Gaussian kernel. - * - * @param samples The samples RDD used to define the empirical distribution. - * @param standardDeviation The standard deviation of the kernel Gaussians. - * @param evaluationPoints The points at which to estimate densities. - * @return An array the same size as evaluationPoints with the density at each point. - */ - def kernelDensity(samples: RDD[Double], standardDeviation: Double, - evaluationPoints: Iterable[Double]): Array[Double] = { - KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray) - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cd6add9d60b0..cf51b24ff777 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]]) - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ @DeveloperApi class MultivariateGaussian ( - val mu: Vector, + val mu: Vector, val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - + private val breezeMu = mu.toBreeze.toDenseVector - + /** * private[mllib] constructor - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) } - + /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t - * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - + /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } - + /** Returns density of this multivariate Gaussian at given point, x */ private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 } - + /** * Calculate distribution dependent components used for the density function: * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. - * - * We here compute distribution-fixed parts + * + * We here compute distribution-fixed parts * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and * D^(-1/2)^ * U, where sigma = U * D * U.t - * + * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite, * we can use the eigendecomposition. We also do not compute the inverse directly; noting - * that - * + * that + * * sigma = U * D * U.t - * inv(Sigma) = U * inv(D) * U.t + * inv(Sigma) = U * inv(D) * U.t * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) - * + * * and thus - * + * * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ - * - * To guard against singular covariance matrices, this method computes both the + * + * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t - + // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length - + try { // log(pseudo-determinant) is sum of the logs of all non-zero singular values val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum - - // calculate the root-pseudo-inverse of the diagonal matrix of singular values + + // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index e597fce2babd..23c8d7c7c807 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -196,7 +196,7 @@ private[stat] object ChiSqTest extends Logging { * Pearson's independence test on the input contingency matrix. * TODO: optimize for SparseMatrix when it becomes supported. */ - def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = { val method = methodFromString(methodName) val numRows = counts.numRows val numCols = counts.numCols diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index dfe3a0b6913e..cecd1fed896d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -169,7 +169,7 @@ object DecisionTree extends Serializable with Logging { numClasses: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).run(input) @@ -768,7 +768,7 @@ object DecisionTree extends Serializable with Logging { */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) val predict = calculatePredict(parentNodeAgg) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1f779584dcff..a835f96d5d0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,12 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) + case Regression => + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, - remappedInput, boostingStrategy, validate=false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -93,8 +93,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost( - input, validationInput, boostingStrategy, validate=true) + case Regression => + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -102,7 +102,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate=true) + validate = true) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging { logInfo(s"$timer") if (persistedInput) input.unpersist() - + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 055e60c7d9c9..069959976a18 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +import org.apache.spark.util.random.SamplingUtils /** * :: Experimental :: @@ -248,7 +249,7 @@ private class RandomForest ( try { nodeIdCache.get.deleteAllCheckpoints() } catch { - case e:IOException => + case e: IOException => logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } @@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) - Some(rng.shuffle(Range(0, metadata.numFeatures).toList) - .take(metadata.numFeaturesPerNode).toArray) + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 331af428533d..f2c78bbabff0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -198,7 +198,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val driverMemory = sc.getConf.getOption("spark.driver.memory") .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB) if (driverMemory <= memThreshold) { logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + s" driver memory (${driverMemory}m)." + @@ -223,14 +223,14 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val dataRDD: DataFrame = sc.parallelize(nodes) .map(NodeData.apply(0, _)) .toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) // Load Parquet data. - val dataRDD = sqlContext.parquetFile(datapath) + val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[NodeData](dataRDD.schema) val nodes = dataRDD.map(NodeData.apply) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 431a839817ea..a6d1398fc267 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -83,7 +83,7 @@ class Node ( def predict(features: Vector) : Double = { if (isLeaf) { predict.predict - } else{ + } else { if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { leftNode.get.predict(features) @@ -151,9 +151,9 @@ class Node ( s"(feature ${split.feature} > ${split.threshold})" } case Categorical => if (left) { - s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})" } else { - s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})" } } } @@ -161,9 +161,9 @@ class Node ( if (isLeaf) { prefix + s"Predict: ${predict.predict}\n" } else { - prefix + s"If ${splitToString(split.get, left=true)}\n" + + prefix + s"If ${splitToString(split.get, left = true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + - prefix + s"Else ${splitToString(split.get, left=false)}\n" + + prefix + s"Else ${splitToString(split.get, left = false)}\n" + rightNode.get.subtreeToString(indentFactor + 1) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 8341219bfa71..905c5fb42bd4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -387,7 +387,7 @@ private[tree] object TreeEnsembleModel extends Logging { val driverMemory = sc.getConf.getOption("spark.driver.memory") .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(Utils.DEFAULT_DRIVER_MEM_MB) if (driverMemory <= memThreshold) { logWarning(s"$className.save() was called, but it may fail because of too little" + s" driver memory (${driverMemory}m)." + @@ -414,7 +414,7 @@ private[tree] object TreeEnsembleModel extends Logging { val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) }.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } /** @@ -437,7 +437,7 @@ private[tree] object TreeEnsembleModel extends Logging { treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply) + val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 0c5b4f9d04a7..bd73a866c8a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -82,8 +82,7 @@ object MFDataGenerator { BLAS.gemm(z, A, B, 1.0, fullData) val df = rank * (m + n - rank) - val sampSize = scala.math.min(scala.math.round(trainSampFact * df), - scala.math.round(.99 * m * n)).toInt + val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt val rand = new Random() val mn = m * n val shuffled = rand.shuffle((0 until mn).toList) @@ -102,8 +101,8 @@ object MFDataGenerator { // optionally generate testing data if (test) { - val testSampSize = scala.math - .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt + val testSampSize = math.min( + math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 681f4c618d30..7c5cfa7bd84c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -82,6 +82,18 @@ object MLUtils { val value = indexAndValue(1).toDouble (index, value) }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, "indices should be one-based and in ascending order" ) + previous = current + i += 1 + } + (label, indices.toArray, values.toArray) } @@ -258,14 +270,30 @@ object MLUtils { * Returns a new vector with `1.0` (bias) appended to the input vector. */ def appendBias(vector: Vector): Vector = { - val vector1 = vector.toBreeze match { - case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(1.0))) - case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(1.0), 1)) - case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + vector match { + case dv: DenseVector => + val inputValues = dv.values + val inputLength = inputValues.length + val outputValues = Array.ofDim[Double](inputLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputLength) + outputValues(inputLength) = 1.0 + Vectors.dense(outputValues) + case sv: SparseVector => + val inputValues = sv.values + val inputIndices = sv.indices + val inputValuesLength = inputValues.length + val dim = sv.size + val outputValues = Array.ofDim[Double](inputValuesLength + 1) + val outputIndices = Array.ofDim[Int](inputValuesLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength) + System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength) + outputValues(inputValuesLength) = 1.0 + outputIndices(inputValuesLength) = dim + Vectors.sparse(dim + 1, outputIndices, outputValues) + case _ => throw new IllegalArgumentException(s"Do not support vector type ${vector.getClass}") } - Vectors.fromBreeze(vector1) } - + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala index 308f7f3578e2..a841c5caf014 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -98,6 +98,8 @@ private[mllib] object NumericParser { } } else if (token == ")") { parsing = false + } else if (token.trim.isEmpty){ + // ignore whitespaces between delim chars, e.g. ", [" } else { // expecting a number items.append(parseDouble(token)) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java new file mode 100644 index 000000000000..d5bd230a957a --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaBucketizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void bucketizerTest() { + double[] splits = {-0.5, 0.0, 0.5}; + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[] { + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits); + + Row[] result = bucketizer.transform(dataset).select("result").collect(); + + for (Row r : result) { + double index = r.getDouble(0); + Assert.assertTrue((index >= 0) && (index <= 1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java new file mode 100644 index 000000000000..845eed61c45c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaDCTSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaDCTSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void javaCompatibilityTest() { + double[] input = new double[] {1D, 2D, 3D, 4D}; + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.dense(input)) + )); + DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ + new StructField("vec", (new VectorUDT()), false, Metadata.empty()) + })); + + double[] expectedResult = input.clone(); + (new DoubleDCT_1D(input.length)).forward(expectedResult, true); + + DCT dct = new DCT() + .setInputCol("vec") + .setOutputCol("resultVec"); + + Row[] result = dct.transform(dataset).select("resultVec").collect(); + Vector resultVec = result[0].getAs("resultVec"); + + Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 23463ab5fe84..599e9cfd23ad 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -55,25 +55,30 @@ public void tearDown() { @Test public void hashingTF() { JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema); - Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); - DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + DataFrame sentenceData = jsql.createDataFrame(jrdd, schema); + Tokenizer tokenizer = new Tokenizer() + .setInputCol("sentence") + .setOutputCol("words"); + DataFrame wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") - .setOutputCol("features") + .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurized = hashingTF.transform(wordsDataFrame); - for (Row r : featurized.select("features", "words", "label").take(3)) { + DataFrame featurizedData = hashingTF.transform(wordsData); + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); + IDFModel idfModel = idf.fit(featurizedData); + DataFrame rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").take(3)) { Vector features = r.getAs(0); Assert.assertEquals(features.size(), numFeatures); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java new file mode 100644 index 000000000000..d82f3b7e8c07 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaNormalizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void normalizer() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures"); + + // Normalize each Vector using $L^2$ norm. + DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + l2NormData.count(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java new file mode 100644 index 000000000000..5e8211c2c511 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaPolynomialExpansionSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void polynomialExpansionTest() { + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create( + Vectors.dense(-2.0, 2.3), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17) + ), + RowFactory.create(Vectors.dense(0.0, 0.0), Vectors.dense(new double[9])), + RowFactory.create( + Vectors.dense(0.6, -1.1), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331) + ) + )); + + StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("expected", new VectorUDT(), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(data, schema); + + Row[] pairs = polyExpansion.transform(dataset) + .select("polyFeatures", "expected") + .collect(); + + for (Row r : pairs) { + double[] polyFeatures = ((Vector)r.get(0)).toArray(); + double[] expected = ((Vector)r.get(1)).toArray(); + Assert.assertArrayEquals(polyFeatures, expected, 1e-1); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java new file mode 100644 index 000000000000..74eb2733f06e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaStandardScalerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void standardScaler() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java new file mode 100644 index 000000000000..35b18c5308f6 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaStringIndexerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + sqlContext = null; + } + + @Test + public void testStringIndexer() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("label", StringType, false) + }); + JavaRDD rdd = jsc.parallelize( + Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + StringIndexer indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex"); + DataFrame output = indexer.fit(dataset).transform(dataset); + + Assert.assertArrayEquals( + new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + output.orderBy("id").select("id", "labelIndex").collect()); + } + + /** An alias for RowFactory.create. */ + private Row c(Object... values) { + return RowFactory.create(values); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java new file mode 100644 index 000000000000..b7c564caad3b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaVectorAssemblerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testVectorAssembler() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("x", DoubleType, false), + createStructField("y", new VectorUDT(), false), + createStructField("name", StringType, false), + createStructField("z", new VectorUDT(), false), + createStructField("n", LongType, false) + }); + Row row = RowFactory.create( + 0, 0.0, Vectors.dense(1.0, 2.0), "a", + Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"x", "y", "z", "n"}) + .setOutputCol("features"); + DataFrame output = assembler.transform(dataset); + Assert.assertEquals( + Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + output.select("features").first().getAs(0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index 161100134c92..c7ae5468b942 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.List; +import java.util.Map; import org.junit.After; import org.junit.Assert; @@ -64,7 +65,8 @@ public void vectorIndexerAPI() { .setMaxCategories(2); VectorIndexerModel model = indexer.fit(data); Assert.assertEquals(model.numFeatures(), 2); - Assert.assertEquals(model.categoryMaps().size(), 1); + Map> categoryMaps = model.javaCategoryMaps(); + Assert.assertEquals(categoryMaps.size(), 1); DataFrame indexedData = model.transform(data); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java new file mode 100644 index 000000000000..39c70157f83c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; + +public class JavaWord2VecSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testJavaWord2Vec() { + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), + RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), + RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + + Word2Vec word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0); + Word2VecModel model = word2Vec.fit(documentDF); + DataFrame result = model.transform(documentDF); + + for (Row r: result.select("result").collect()) { + double[] polyFeatures = ((Vector)r.get(0)).toArray(); + Assert.assertEquals(polyFeatures.length, 3); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index e7df10dfa63a..9890155e9f86 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -50,6 +50,7 @@ public void testParams() { testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyStringParam(), "a"); + Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 3a41890b92d6..3ae09d39ef50 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -51,7 +51,8 @@ public String uid() { public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } public JavaTestParams setMyIntParam(int value) { - set(myIntParam_, value); return this; + set(myIntParam_, value); + return this; } private DoubleParam myDoubleParam_; @@ -60,7 +61,8 @@ public JavaTestParams setMyIntParam(int value) { public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } public JavaTestParams setMyDoubleParam(double value) { - set(myDoubleParam_, value); return this; + set(myDoubleParam_, value); + return this; } private Param myStringParam_; @@ -69,7 +71,18 @@ public JavaTestParams setMyDoubleParam(double value) { public String getMyStringParam() { return getOrDefault(myStringParam_); } public JavaTestParams setMyStringParam(String value) { - set(myStringParam_, value); return this; + set(myStringParam_, value); + return this; + } + + private DoubleArrayParam myDoubleArrayParam_; + public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } + + public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } + + public JavaTestParams setMyDoubleArrayParam(double[] value) { + set(myDoubleArrayParam_, value); + return this; } private void init() { @@ -79,7 +92,19 @@ private void init() { List validStrings = Lists.newArrayList("a", "b"); myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); - setDefault(myIntParam_, 1); - setDefault(myDoubleParam_, 0.5); + myDoubleArrayParam_ = + new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); + + setDefault(myIntParam(), 1); + setDefault(myIntParam().w(1)); + setDefault(myDoubleParam(), 0.5); + setDefault(myIntParam().w(1), myDoubleParam().w(0.5)); + setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); + setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); + } + + @Override + public JavaTestParams copy(ParamMap extra) { + return defaultCopy(extra); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala index 67c262d0f9d8..928301523fba 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class IdentifiableSuite extends FunSuite { +class IdentifiableSuite extends SparkFunSuite { import IdentifiableSuite.Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 71fb7f13c39c..3771c0ea7ad8 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception { @Test public void testModelTypeSetters() { NaiveBayes nb = new NaiveBayes() - .setModelType("Bernoulli") - .setModelType("Multinomial"); + .setModelType("bernoulli") + .setModelType("multinomial"); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java similarity index 95% rename from mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java rename to mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 640d2ec55e4e..55787f8606d4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml.classification; +package org.apache.spark.mllib.classification; import java.io.Serializable; import java.util.List; @@ -28,7 +28,6 @@ import org.junit.Test; import org.apache.spark.SparkConf; -import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java new file mode 100644 index 000000000000..467a7a69e8f3 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaGaussianMixtureSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGaussianMixture"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runGaussianMixture() { + List points = Lists.newArrayList( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0) + ); + + JavaRDD data = sc.parallelize(points, 2); + GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) + .run(data); + assertEquals(model.gaussians().length, 2); + JavaRDD predictions = model.predict(data); + predictions.first(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 96c2da169961..581c033f08eb 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -107,6 +107,10 @@ public void distributedLDAModel() { // Check: log probabilities assert(model.logLikelihood() < 0.0); assert(model.logPrior() < 0.0); + + // Check: topic distributions + JavaPairRDD topicDistributions = model.javaTopicDistributions(); + assertEquals(topicDistributions.count(), corpus.count()); } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java new file mode 100644 index 000000000000..3b0e879eec77 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.apache.spark.streaming.JavaTestUtils.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +public class JavaStreamingKMeansSuite implements Serializable { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void javaAPI() { + List trainingBatch = Lists.newArrayList( + Vectors.dense(1.0), + Vectors.dense(0.0)); + JavaDStream training = + attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); + List> testBatch = Lists.newArrayList( + new Tuple2(10, Vectors.dense(1.0)), + new Tuple2(11, Vectors.dense(0.0))); + JavaPairDStream test = JavaPairDStream.fromJavaDStream( + attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + StreamingKMeans skmeans = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0}); + skmeans.trainOn(training); + JavaPairDStream prediction = skmeans.predictOnValues(test); + attachTestOutputStream(prediction.count()); + runStreams(ssc, 2, 2); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java new file mode 100644 index 000000000000..62f7f26b7c98 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaStatisticsSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaStatistics"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testCorr() { + JavaRDD x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); + + Double corr1 = Statistics.corr(x, y); + Double corr2 = Statistics.corr(x, y, "pearson"); + // Check default method + assertEquals(corr1, corr2); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2b04a3034782..63d2fa31c749 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml +import scala.collection.JavaConverters._ + import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame -class PipelineSuite extends FunSuite { +class PipelineSuite extends SparkFunSuite { abstract class MyModel extends Model[MyModel] @@ -81,4 +84,28 @@ class PipelineSuite extends FunSuite { pipeline.fit(dataset) } } + + test("PipelineModel.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) + require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + + test("pipeline model constructors") { + val transform0 = mock[Transformer] + val model1 = mock[MyModel] + + val stages = Array(transform0, model1) + val pipelineModel0 = new PipelineModel("pipeline0", stages) + assert(pipelineModel0.uid === "pipeline0") + assert(pipelineModel0.stages === stages) + + val stagesAsList = stages.toList.asJava + val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList) + assert(pipelineModel1.uid === "pipeline1") + assert(pipelineModel1.stages === stages) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 17ddd335deb6..512cffb1acb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class AttributeGroupSuite extends FunSuite { +class AttributeGroupSuite extends SparkFunSuite { test("attribute group") { val attrs = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index ec9b717e41ce..c5fd2f9d5a22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class AttributeSuite extends FunSuite { +class AttributeSuite extends SparkFunSuite { test("default numeric attribute") { val attr: NumericAttribute = NumericAttribute.defaultAttr @@ -216,5 +215,10 @@ class AttributeSuite extends FunSuite { assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute) val fldWithMeta = new StructField("x", DoubleType, false, metadata) assert(Attribute.fromStructField(fldWithMeta).isNumeric) + // Attribute.fromStructField should accept any NumericType, not just DoubleType + val longFldWithMeta = new StructField("x", LongType, false, metadata) + assert(Attribute.fromStructField(longFldWithMeta).isNumeric) + val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata) + assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 3fdc66be8a31..73b4805c4c59 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -17,19 +17,18 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, - DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - -class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -56,6 +55,12 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) } + test("params") { + ParamsSuite.checkParams(new DecisionTreeClassifier) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + ParamsSuite.checkParams(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests calling train() ///////////////////////////////////////////////////////////////////////////// @@ -251,7 +256,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { */ } -private[ml] object DecisionTreeClassifierSuite extends FunSuite { +private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -266,7 +271,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index ea86867f1161..82c345491bb3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -31,7 +33,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTClassifierSuite.compareAPIs @@ -52,6 +54,14 @@ class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) } + test("params") { + ParamsSuite.checkParams(new GBTClassifier) + val model = new GBTClassificationModel("gbtc", + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(1.0)) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features: Log Loss") { val categoricalFeatures = Map.empty[Int, Int] testCombinations.foreach { @@ -128,7 +138,7 @@ private object GBTClassifierSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 43765241a20b..ba8fbee84197 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,40 +17,38 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) - /** - * Here is the instruction describing how to export the test data into CSV format - * so we can validate the training accuracy compared with R's glmnet package. - * - * import org.apache.spark.mllib.classification.LogisticRegressionSuite - * val nPoints = 10000 - * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) - * val xMean = Array(5.843, 3.057, 3.758, 1.199) - * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - * weights, xMean, xVariance, true, nPoints, 42), 1) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " - * + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. + + import org.apache.spark.mllib.classification.LogisticRegressionSuite + val nPoints = 10000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") */ binaryDataset = { val nPoints = 10000 @@ -65,6 +63,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { } } + test("params") { + ParamsSuite.checkParams(new LogisticRegression) + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + test("logistic regression: default params") { val lr = new LogisticRegression assert(lr.getLabelCol === "label") @@ -83,6 +87,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getRawPredictionCol === "rawPrediction") assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) + assert(model.hasParent) } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -206,22 +211,23 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(true) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 2.8366423 - * data.V2 -0.5895848 - * data.V3 0.8931147 - * data.V4 -0.3925051 - * data.V5 -0.7996864 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 */ val interceptR = 2.8366423 val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864) @@ -237,23 +243,24 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(false) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = - * coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.3534996 - * data.V3 1.2964482 - * data.V4 -0.3571741 - * data.V5 -0.7407946 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 */ val interceptR = 0.0 val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946) @@ -270,22 +277,23 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.05627428 - * data.V2 . - * data.V3 . - * data.V4 -0.04325749 - * data.V5 -0.02481551 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 */ val interceptR = -0.05627428 val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551) @@ -302,23 +310,24 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 . - * data.V3 . - * data.V4 -0.05189203 - * data.V5 -0.03891782 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 */ val interceptR = 0.0 val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782) @@ -335,22 +344,23 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.15021751 - * data.V2 -0.07251837 - * data.V3 0.10724191 - * data.V4 -0.04865309 - * data.V5 -0.10062872 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 */ val interceptR = 0.15021751 val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872) @@ -367,23 +377,24 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.06099165 - * data.V3 0.12857058 - * data.V4 -0.04708770 - * data.V5 -0.09799775 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 */ val interceptR = 0.0 val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775) @@ -400,22 +411,23 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.57734851 - * data.V2 -0.05310287 - * data.V3 . - * data.V4 -0.08849250 - * data.V5 -0.15458796 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.57734851 + data.V2 -0.05310287 + data.V3 . + data.V4 -0.08849250 + data.V5 -0.15458796 */ val interceptR = 0.57734851 val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796) @@ -432,23 +444,24 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.001005743 - * data.V3 0.072577857 - * data.V4 -0.081203769 - * data.V5 -0.142534158 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + intercept=FALSE)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.001005743 + data.V3 0.072577857 + data.V4 -0.081203769 + data.V5 -0.142534158 */ val interceptR = 0.0 val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158) @@ -475,16 +488,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { classSummarizer1.merge(classSummarizer2) }).histogram - /** - * For binary logistic regression with strong L1 regularization, all the weights will be zeros. - * As a result, - * {{{ - * P(0) = 1 / (1 + \exp(b)), and - * P(1) = \exp(b) / (1 + \exp(b)) - * }}}, hence - * {{{ - * b = \log{P(1) / P(0)} = \log{count_1 / count_0} - * }}} + /* + For binary logistic regression with strong L1 regularization, all the weights will be zeros. + As a result, + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} */ val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) val weightsTheory = Array(0.0, 0.0, 0.0, 0.0) @@ -495,22 +508,23 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6) assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.2480643 - * data.V2 0.0000000 - * data.V3 . - * data.V4 . - * data.V5 . + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE) + label = factor(data$V1) + features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + weights + + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.2480643 + data.V2 0.0000000 + data.V3 . + data.V4 . + data.V5 . */ val interceptR = -0.248065 val weightsR = Array(0.0, 0.0, 0.0, 0.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 990cfb08af83..75cf5bd4ead4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,28 +17,29 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { - @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ @transient var rdd: RDD[LabeledPoint] = _ override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) + val nPoints = 1000 // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. @@ -54,6 +55,13 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(rdd) } + test("params") { + ParamsSuite.checkParams(new OneVsRest) + val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0) + val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel)) + ParamsSuite.checkParams(model) + } + test("one-vs-rest: default params") { val numClasses = 3 val ova = new OneVsRest() @@ -95,6 +103,35 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) ova.fit(datasetWithLabelMetadata) } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } + + test("OneVsRest.copy and OneVsRestModel.copy") { + val lr = new LogisticRegression() + .setMaxIter(1) + + val ovr = new OneVsRest() + withClue("copy with classifier unset should work") { + ovr.copy(ParamMap(lr.maxIter -> 10)) + } + ovr.setClassifier(lr) + val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10)) + require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects") + require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, + "copy should handle extra classifier params") + + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) + ovrModel.models.foreach { case m: LogisticRegressionModel => + require(m.getThreshold === 0.1, "copy should handle extra model params") + } + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 08f86fa45bc1..1b6b69c7dc71 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -28,11 +29,10 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestClassifierSuite.compareAPIs @@ -63,6 +63,13 @@ class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) } + test("params") { + ParamsSuite.checkParams(new RandomForestClassifier) + val model = new RandomForestClassificationModel("rfc", + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() @@ -158,9 +165,11 @@ private object RandomForestClassifierSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.hasParent) + assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala rename to mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index 8df4f3b554c4..def869fe6677 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -15,17 +15,14 @@ * limitations under the License. */ -package org.apache.spark.scheduler.cluster.mesos +package org.apache.spark.ml.evaluation -import org.apache.spark.SparkContext +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite -private[spark] object MemoryUtils { - // These defaults copied from YARN - val OVERHEAD_FRACTION = 0.10 - val OVERHEAD_MINIMUM = 384 +class BinaryClassificationEvaluatorSuite extends SparkFunSuite { - def calculateTotalMemory(sc: SparkContext): Int = { - sc.conf.getInt("spark.mesos.executor.memoryOverhead", - math.max(OVERHEAD_FRACTION * sc.executorMemory, OVERHEAD_MINIMUM).toInt) + sc.executorMemory + test("params") { + ParamsSuite.checkParams(new BinaryClassificationEvaluator) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala new file mode 100644 index 000000000000..5b203784559e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ + +class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new RegressionEvaluator) + } + + test("Regression Evaluator: default params") { + /** + * Here is the instruction describing how to export the test data into CSV format + * so we can validate the metrics compared with R's mmetric package. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, + * Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1)) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) + * .saveAsTextFile("path") + */ + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + /** + * Using the following R code to load the data, train the model and evaluate metrics. + * + * > library("glmnet") + * > library("rminer") + * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * > label <- as.numeric(data$V1) + * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0) + * > rmse <- mmetric(label, predict(model, features), metric='RMSE') + * > mae <- mmetric(label, predict(model, features), metric='MAE') + * > r2 <- mmetric(label, predict(model, features), metric='R2') + */ + val trainer = new LinearRegression + val model = trainer.fit(dataset) + val predictions = model.transform(dataset) + + // default = rmse + val evaluator = new RegressionEvaluator() + assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001) + + // r2 score + evaluator.setMetricName("r2") + assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + + // mae + evaluator.setMetricName("mae") + assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index caf1b759593f..208604398366 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,24 +17,24 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - +import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends FunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Double] = _ - @transient var sqlContext: SQLContext = _ override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) } + test("params") { + ParamsSuite.checkParams(new Binarizer) + } + test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) val dataFrame: DataFrame = sqlContext.createDataFrame( @@ -52,7 +52,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext { test("Binarize continuous features with setter") { val threshold: Double = 0.2 - val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) val dataFrame: DataFrame = sqlContext.createDataFrame( data.zip(thresholdBinarized)).toDF("feature", "expected") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 20d2f3ac6696..ec85e0d151e0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -19,21 +19,17 @@ package org.apache.spark.ml.feature import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -class BucketizerSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.sql.{DataFrame, Row} - @transient private var sqlContext: SQLContext = _ +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) + test("params") { + ParamsSuite.checkParams(new Bucketizer) } test("Bucket continuous features, without -inf,inf") { @@ -117,7 +113,7 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext { } } -private object BucketizerSuite extends FunSuite { +private object BucketizerSuite extends SparkFunSuite { /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */ def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = { require(feature >= splits.head) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala new file mode 100644 index 000000000000..37ed2367c33f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class DCTTestData(vec: Vector, wantedVec: Vector) + +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("forward transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = false + + testDCT(data, inverse) + } + + test("inverse transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = true + + testDCT(data, inverse) + } + + private def testDCT(data: Vector, inverse: Boolean): Unit = { + val expectedResultBuffer = data.toArray.clone() + if (inverse) { + (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true) + } else { + (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true) + } + val expectedResult = Vectors.dense(expectedResultBuffer) + + val dataset = sqlContext.createDataFrame(Seq( + DCTTestData(data, expectedResult) + )) + + val transformer = new DCT() + .setInputCol("vec") + .setOutputCol("resultVec") + .setInverse(inverse) + + transformer.transform(dataset) + .select("resultVec", "wantedVec") + .collect() + .foreach { case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala new file mode 100644 index 000000000000..4157b84b29d0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new HashingTF) + } + + test("hashingTF") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a b b c d".split(" ").toSeq) + )).toDF("id", "words") + val n = 100 + val hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("features") + .setNumFeatures(n) + val output = hashingTF.transform(df) + val attrGroup = AttributeGroup.fromStructField(output.schema("features")) + require(attrGroup.numAttributes === Some(n)) + val features = output.select("features").first().getAs[Vector](0) + // Assume perfect hash on "a", "b", "c", and "d". + def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n) + val expected = Vectors.sparse(n, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) + assert(features ~== expected absTol 1e-14) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index eaee3443c1f2..08f80af03429 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,21 +17,15 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext} - -class IDFSuite extends FunSuite with MLlibTestSparkContext { - - @transient var sqlContext: SQLContext = _ +import org.apache.spark.sql.Row - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -46,6 +40,12 @@ class IDFSuite extends FunSuite with MLlibTestSparkContext { } } + test("params") { + ParamsSuite.checkParams(new IDF) + val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0))) + ParamsSuite.checkParams(model) + } + test("compute IDF with default parameter") { val numOfFeatures = 4 val data = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala new file mode 100644 index 000000000000..c452054bec92 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("MinMaxScaler fit basic case") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.dense(1, 0, Long.MinValue), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)), + Vectors.sparse(3, Array(0), Array(1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(-5, 0, -5), + Vectors.dense(0, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(5, 5)), + Vectors.sparse(3, Array(0), Array(-2.5))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaled") + .setMin(-5) + .setMax(5) + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), "Transformed vector is different with expected.") + } + } + + test("MinMaxScaler arguments max must be larger than min") { + withClue("arguments max must be larger than min") { + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(10).setMax(0) + scaler.validateParams() + } + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(0).setMax(0) + scaler.validateParams() + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala new file mode 100644 index 000000000000..ab97e3dbc6ee --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) + +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.NGramSuite._ + + test("default behavior yields bigram features") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("Test", "for", "ngram", "."), + Array("Test for", "for ngram", "ngram .") + ))) + testNGram(nGram, dataset) + } + + test("NGramLength=4 yields length 4 n-grams") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array("a b c d", "b c d e") + ))) + testNGram(nGram, dataset) + } + + test("empty input yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array(), + Array() + ))) + testNGram(nGram, dataset) + } + + test("input array < n yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(6) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array() + ))) + testNGram(nGram, dataset) + } +} + +object NGramSuite extends SparkFunSuite { + + def testNGram(t: NGram, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("nGrams", "wantedNGrams") + .collect() + .foreach { case Row(actualNGrams, wantedNGrams) => + assert(actualNGrams === wantedNGrams) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9d09f24709e2..9f03470b7f32 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 92ec407b98d6..65846a846b7b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,20 +17,15 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} - - -class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { - private var sqlContext: SQLContext = _ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -42,15 +37,20 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { indexer.transform(df) } - test("OneHotEncoder includeFirst = true") { + test("params") { + ParamsSuite.checkParams(new OneHotEncoder) + } + + test("OneHotEncoder dropLast = false") { val transformed = stringIndexed() val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") + .setDropLast(false) val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -59,22 +59,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { assert(output === expected) } - test("OneHotEncoder includeFirst = false") { + test("OneHotEncoder dropLast = true") { val transformed = stringIndexed() val encoder = new OneHotEncoder() - .setIncludeFirst(false) .setInputCol("labelIndex") .setOutputCol("labelVec") val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0), - (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)) + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) assert(output === expected) } + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCol("size") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val encoder = new OneHotEncoder() + .setInputCol("index") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala new file mode 100644 index 000000000000..d0ae36b28c7a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} +import org.apache.spark.sql.Row + +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new PCA) + val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] + val model = new PCAModel("pca", new OldPCAModel(2, mat)) + ParamsSuite.checkParams(model) + } + + test("pca") { + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val mat = new RowMatrix(dataRDD) + val pc = mat.computePrincipalComponents(3) + val expected = mat.multiply(pc).rows + + val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pca_features") + .setK(3) + .fit(df) + + pca.transform(df).select("pca_features", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index c1d64fba0aa8..29eebd8960eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,21 +17,19 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.scalatest.exceptions.TestFailedException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext} -import org.scalatest.exceptions.TestFailedException - -class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.sql.Row - @transient var sqlContext: SQLContext = _ +class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) + test("params") { + ParamsSuite.checkParams(new PolynomialExpansion) } test("Polynomial expansion with default parameter") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b6939e587041..99f82bea4268 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,18 +17,17 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SQLContext -class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { - private var sqlContext: SQLContext = _ +class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) + test("params") { + ParamsSuite.checkParams(new StringIndexer) + val model = new StringIndexerModel("indexer", Array("a", "b")) + ParamsSuite.checkParams(model) } test("StringIndexer") { @@ -68,4 +67,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexerModel should keep silent if the input column does not exist.") { + val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) + .setInputCol("label") + .setOutputCol("labelIndex") + val df = sqlContext.range(0L, 10L) + assert(indexerModel.transform(df).eq(df)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index d186ead8f542..e5fd21c3f6fc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,64 +19,66 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { +class TokenizerSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new Tokenizer) + } +} + +class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ - - @transient var sqlContext: SQLContext = _ - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) + test("params") { + ParamsSuite.checkParams(new RegexTokenizer) } test("RegexTokenizer") { - val tokenizer = new RegexTokenizer() + val tokenizer0 = new RegexTokenizer() + .setGaps(false) + .setPattern("\\w+|\\p{Punct}") .setInputCol("rawText") .setOutputCol("tokens") - val dataset0 = sqlContext.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) )) - testRegexTokenizer(tokenizer, dataset0) + testRegexTokenizer(tokenizer0, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) + tokenizer0.setMinTokenLength(3) + testRegexTokenizer(tokenizer0, dataset1) - tokenizer.setMinTokenLength(3) - testRegexTokenizer(tokenizer, dataset1) - - tokenizer - .setPattern("\\s") - .setGaps(true) - .setMinTokenLength(0) + val tokenizer2 = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") val dataset2 = sqlContext.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct")) + TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) )) - testRegexTokenizer(tokenizer, dataset2) + testRegexTokenizer(tokenizer2, dataset2) } } -object RegexTokenizerSuite extends FunSuite { +object RegexTokenizerSuite extends SparkFunSuite { def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { t.transform(dataset) .select("tokens", "wantedTokens") .collect() - .foreach { - case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } + .foreach { case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 0db27607bc27..bb4d5b983e0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,20 +17,18 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} - -class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col - @transient var sqlContext: SQLContext = _ +class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) + test("params") { + ParamsSuite.checkParams(new VectorAssembler) } test("assemble") { @@ -68,4 +66,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) } } + + test("ML attributes") { + val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") + val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) + val user = new AttributeGroup("user", Array( + NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), + NumericAttribute.defaultAttr.withName("salary"))) + val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) + val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + .select( + col("browser").as("browser", browser.toMetadata()), + col("hour").as("hour", hour.toMetadata()), + col("count"), // "count" is an integer column without ML attribute + col("user").as("user", user.toMetadata()), + col("ad")) // "ad" is a vector column without ML attribute + val assembler = new VectorAssembler() + .setInputCols(Array("browser", "hour", "count", "user", "ad")) + .setOutputCol("features") + val output = assembler.transform(df) + val schema = output.schema + val features = AttributeGroup.fromStructField(schema("features")) + assert(features.size === 7) + val browserOut = features.getAttr(0) + assert(browserOut === browser.withIndex(0).withName("browser")) + val hourOut = features.getAttr(1) + assert(hourOut === hour.withIndex(1).withName("hour")) + val countOut = features.getAttr(2) + assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2)) + val userGenderOut = features.getAttr(3) + assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) + val userSalaryOut = features.getAttr(4) + assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 38dc83b1241c..8c85c96d5c6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,22 +19,18 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} - +import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { import VectorIndexerSuite.FeatureData - @transient var sqlContext: SQLContext = _ - // identical, of length 3 @transient var densePoints1: DataFrame = _ @transient var sparsePoints1: DataFrame = _ @@ -86,7 +82,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints2Seq, sparsePoints2Seq) - sqlContext = new SQLContext(sc) densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) @@ -97,6 +92,12 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { private def getIndexer: VectorIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexed") + test("params") { + ParamsSuite.checkParams(new VectorIndexer) + val model = new VectorIndexerModel("indexer", 1, Map.empty) + ParamsSuite.checkParams(model) + } + test("Cannot fit an empty DataFrame") { val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) val vectorIndexer = getIndexer diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 03ba86670d45..aa6ce533fd88 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,14 +17,21 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} + +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Word2Vec) + val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f)))) + ParamsSuite.checkParams(model) + } test("Word2Vec") { val sqlContext = new SQLContext(sc) @@ -35,9 +42,9 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val codes = Map( - "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), - "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), - "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) ) val expected = doc.map { sentence => @@ -52,6 +59,7 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { .setVectorSize(3) .setInputCol("text") .setOutputCol("result") + .setSeed(42L) .fit(docDF) model.transform(docDF).select("result", "expected").collect().foreach { diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 1505ad872536..778abcba22c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.impl import scala.collection.JavaConverters._ -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.tree._ @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} -private[ml] object TreeTests extends FunSuite { +private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index b96874f3a882..050d4170ea01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.param -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ParamsSuite extends FunSuite { +class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() @@ -27,7 +27,7 @@ class ParamsSuite extends FunSuite { import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") - assert(maxIter.doc === "max number of iterations (>= 0)") + assert(maxIter.doc === "maximum number of iterations (>= 0)") assert(maxIter.parent === uid) assert(maxIter.toString === s"${uid}__maxIter") assert(!maxIter.isValid(-1)) @@ -36,7 +36,7 @@ class ParamsSuite extends FunSuite { solver.setMaxIter(5) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 5)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)") assert(inputCol.toString === s"${uid}__inputCol") @@ -120,7 +120,7 @@ class ParamsSuite extends FunSuite { intercept[NoSuchElementException](solver.getInputCol) assert(solver.explainParam(maxIter) === - "maxIter: max number of iterations (>= 0) (default: 10, current: 100)") + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) @@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validateParams() } - solver.validateParams(ParamMap(inputCol -> "input")) + solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) @@ -201,3 +201,31 @@ class ParamsSuite extends FunSuite { assert(inArray(1) && inArray(2) && !inArray(0)) } } + +object ParamsSuite extends SparkFunSuite { + + /** + * Checks common requirements for [[Params.params]]: + * - params are ordered by names + * - param parent has the same UID as the object's UID + * - param name is the same as the param method name + * - obj.copy should return the same type as the obj + */ + def checkParams(obj: Params): Unit = { + val clazz = obj.getClass + + val params = obj.params + val paramNames = params.map(_.name) + require(paramNames === paramNames.sorted, "params must be ordered by names") + params.foreach { p => + assert(p.parent === obj.uid) + assert(obj.getParam(p.name) === p) + // TODO: Check that setters return self, which needs special handling for generic types. + } + + val copyMethod = clazz.getMethod("copy", classOf[ParamMap]) + val copyReturnType = copyMethod.getReturnType + require(copyReturnType === obj.getClass, + s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index a9e78366ad98..275924834453 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H require(isDefined(inputCol)) } - override def copy(extra: ParamMap): TestParams = { - super.copy(extra).asInstanceOf[TestParams] - } + override def copy(extra: ParamMap): TestParams = defaultCopy(extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala new file mode 100644 index 000000000000..b3af81a3c60b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.{ParamMap, Params} + +class SharedParamsSuite extends SparkFunSuite { + + test("outputCol") { + + class Obj(override val uid: String) extends Params with HasOutputCol { + override def copy(extra: ParamMap): Obj = defaultCopy(extra) + } + + val obj = new Obj("obj") + + assert(obj.hasDefault(obj.outputCol)) + assert(obj.getOrDefault(obj.outputCol) === "obj__output") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index fc7349330cf8..2e5cfe7027eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.scalatest.FunSuite -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -36,16 +35,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.Utils -class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var sqlContext: SQLContext = _ private var tempDir: File = _ override def beforeAll(): Unit = { super.beforeAll() tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) - sqlContext = new SQLContext(sc) } override def afterAll(): Unit = { @@ -345,6 +342,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { .setImplicitPrefs(implicitPrefs) .setNumUserBlocks(numUserBlocks) .setNumItemBlocks(numItemBlocks) + .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) val predictions = model.transform(test.toDF()) @@ -425,17 +423,18 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating)) - val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4) + val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4, seed = 0) assert(longUserFactors.first()._1.getClass === classOf[Long]) val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating)) - val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4) + val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4, seed = 0) assert(strUserFactors.first()._1.getClass === classOf[String]) } test("nonnegative constraint") { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) - val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true) + val (userFactors, itemFactors) = + ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true, seed = 0) def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = { factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _) } @@ -459,7 +458,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { test("partitioner in returned factors") { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) val (userFactors, itemFactors) = ALS.train( - ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4) + ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4, seed = 0) for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) { assert(userFactors.partitioner.isDefined, s"$tpe factors should have partitioner.") val part = userFactors.partitioner.get @@ -476,8 +475,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { test("als with large number of iterations") { val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) - ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2) - ALS.train( - ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, seed = 0) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, + implicitPrefs = true, seed = 0) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 1196a772dfdd..33aa9d0d6234 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeRegressorSuite.compareAPIs @@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // TODO: test("model save/load") SPARK-6725 } -private[ml] object DecisionTreeRegressorSuite extends FunSuite { +private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -83,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 40e7e3273e96..9682edcd9ba8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTRegressorSuite.compareAPIs @@ -68,6 +68,26 @@ class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { } } + test("GBTRegressor behaves reasonably on toy data") { + val df = sqlContext.createDataFrame(Seq( + LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), + LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), + LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), + LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), + LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), + LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) + )) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(df) + val preds = model.transform(df) + val predictions = preds.select("prediction").map(_.getDouble(0)) + // Checks based on SPARK-8736 (to ensure it is not doing classification) + assert(predictions.max() > 2) + assert(predictions.min() < -1) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { @@ -129,7 +149,7 @@ private object GBTRegressorSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 80323ef5201a..5f39d44f3735 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,54 +17,62 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + @transient var datasetWithoutIntercept: DataFrame = _ + + /* + In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + is the same as the one trained by R's glmnet package. The following instruction + describes how to reproduce the data in R. - /** - * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML - * is the same as the one trained by R's glmnet package. The following instruction - * describes how to reproduce the data in R. - * - * import org.apache.spark.mllib.util.LinearDataGenerator - * val data = - * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + import org.apache.spark.mllib.util.LinearDataGenerator + val data = + sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), + Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) + .saveAsTextFile("path") */ override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + /* + datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating + training model without intercept + */ + datasetWithoutIntercept = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) + } test("linear regression with intercept without regularization") { val trainer = new LinearRegression val model = trainer.fit(dataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * library("glmnet") - * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) - * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) - * label <- as.numeric(data$V1) - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.300528 - * as.numeric.data.V2. 4.701024 - * as.numeric.data.V3. 7.198257 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.300528 + as.numeric.data.V2. 4.701024 + as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 val weightsR = Array(4.700706, 7.199082) @@ -81,20 +89,56 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept without regularization") { + val trainer = (new LinearRegression).setFitIntercept(false) + val model = trainer.fit(dataset) + val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.995908 + as.numeric.data.V3. 5.275131 + */ + val weightsR = Array(6.995908, 5.275131) + + assert(model.intercept ~== 0 relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + /* + Then again with the data with no intercept: + > weightsWithoutIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 7.19943 + */ + val weightsWithoutInterceptR = Array(4.70011, 7.19943) + + assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3) + assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3) + assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3) + } + test("linear regression with intercept with L1 regularization") { val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.311546 - * as.numeric.data.V2. 2.123522 - * as.numeric.data.V3. 4.605651 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.24300 + as.numeric.data.V2. 4.024821 + as.numeric.data.V3. 6.679841 */ - val interceptR = 6.243000 + val interceptR = 6.24300 val weightsR = Array(4.024821, 6.679841) assert(model.intercept ~== interceptR relTol 1E-3) @@ -109,18 +153,48 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.299752 + as.numeric.data.V3. 4.772913 + */ + val interceptR = 0.0 + val weightsR = Array(6.299752, 4.772913) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + test("linear regression with intercept with L2 regularization") { val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.328062 - * as.numeric.data.V2. 3.222034 - * as.numeric.data.V3. 4.926260 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.328062 + as.numeric.data.V2. 3.222034 + as.numeric.data.V3. 4.926260 */ val interceptR = 5.269376 val weightsR = Array(3.736216, 5.712356) @@ -137,18 +211,48 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } + test("linear regression without intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.522875 + as.numeric.data.V3. 4.214502 + */ + val interceptR = 0.0 + val weightsR = Array(5.522875, 4.214502) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + test("linear regression with intercept with ElasticNet regularization") { val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.324108 - * as.numeric.data.V2. 3.168435 - * as.numeric.data.V3. 5.200403 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.324108 + as.numeric.data.V2. 3.168435 + as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 val weightsR = Array(3.670489, 6.001122) @@ -164,4 +268,34 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(prediction1 ~== prediction2 relTol 1E-5) } } + + test("linear regression without intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + .setFitIntercept(false) + val model = trainer.fit(dataset) + + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.dataM.V2. 5.673348 + as.numeric.dataM.V3. 4.322251 + */ + val interceptR = 0.0 + val weightsR = Array(5.673348, 4.322251) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 3efffbb763b7..b24ecaa57c89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestRegressorSuite.compareAPIs @@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { */ } -private object RandomForestRegressorSuite extends FunSuite { +private object RandomForestRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -114,7 +113,7 @@ private object RandomForestRegressorSuite extends FunSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 05313d440fbf..db64511a7605 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,15 +17,19 @@ package org.apache.spark.ml.tuning -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol +import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { +class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @@ -52,5 +56,90 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) + } + + test("cross validation with linear regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + val trainer = new LinearRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(trainer.regParam, Array(1000.0, 0.001)) + .addGrid(trainer.maxIter, Array(0, 10)) + .build() + val eval = new RegressionEvaluator() + val cv = new CrossValidator() + .setEstimator(trainer) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) + + eval.setMetricName("r2") + val cvModel2 = cv.fit(dataset) + val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent2.getRegParam === 0.001) + assert(parent2.getMaxIter === 10) + assert(cvModel2.avgMetrics.length === lrParamMaps.length) + } + + test("validateParams should check estimatorParamMaps") { + import CrossValidatorSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new CrossValidator() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + + cv.validateParams() // This should pass. + + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object CrossValidatorSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = { + throw new UnsupportedOperationException + } + + override def transformSchema(schema: StructType): StructType = { + throw new UnsupportedOperationException + } + + override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = { + throw new UnsupportedOperationException + } + + override val uid: String = "eval" + + override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala index 20aa100112bf..810b70049ec1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.{ParamMap, TestParams} -class ParamGridBuilderSuite extends FunSuite { +class ParamGridBuilderSuite extends SparkFunSuite { val solver = new TestParams() import solver.{inputCol, maxIter} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index a629dba8a426..59944416d96a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.api.python -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating -class PythonMLLibAPISuite extends FunSuite { +class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() @@ -84,7 +83,7 @@ class PythonMLLibAPISuite extends FunSuite { val smt = new SparseMatrix( 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), - isTransposed=true) + isTransposed = true) val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix] assert(smt.toArray === nsmt.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 966811a5a326..2473510e1351 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.util.Random import scala.util.control.Breaks._ -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -119,7 +119,7 @@ object LogisticRegressionSuite { } // Preventing the overflow when we compute the probability val maxMargin = margins.max - if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin // Computing the probabilities for each class from the margins. val norm = { @@ -130,7 +130,7 @@ object LogisticRegressionSuite { } temp } - for (i <-0 until nClasses) probs(i) /= norm + for (i <- 0 until nClasses) probs(i) /= norm // Compute the cumulative probability so we can generate a random number and assign a label. for (i <- 1 until nClasses) probs(i) += probs(i - 1) @@ -169,7 +169,7 @@ object LogisticRegressionSuite { } -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -196,6 +196,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M .setStepSize(10.0) .setRegParam(0.0) .setNumIterations(20) + .setConvergenceTol(0.0005) val model = lr.run(testRDD) @@ -541,7 +542,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } -class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 40a79a1f19bd..f7fc8730606a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.mllib.classification import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils - object NaiveBayesSuite { + import NaiveBayes.{Multinomial, Bernoulli} + private def calcLabel(p: Double, pi: Array[Double]): Int = { var sum = 0.0 for (j <- 0 until pi.length) { @@ -48,7 +47,7 @@ object NaiveBayesSuite { theta: Array[Array[Double]], // CXD nPoints: Int, seed: Int, - modelType: String = "Multinomial", + modelType: String = Multinomial, sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) @@ -58,10 +57,10 @@ object NaiveBayesSuite { for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { - case "Bernoulli" => Array.tabulate[Double] (D) { j => + case Bernoulli => Array.tabulate[Double] (D) { j => if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 } - case "Multinomial" => + case Multinomial => val mult = BrzMultinomial(BDV(_theta(y))) val emptyMap = (0 until D).map(x => (x, 0.0)).toMap val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { @@ -70,7 +69,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) @@ -79,16 +78,16 @@ object NaiveBayesSuite { /** Bernoulli NaiveBayes with binary labels, 3 features */ private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), - pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - "Bernoulli") + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli) /** Multinomial NaiveBayes with binary labels, 3 features */ private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), - pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - "Multinomial") + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial) } -class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + + import NaiveBayes.{Multinomial, Bernoulli} def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { @@ -117,6 +116,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } + test("model types") { + assert(Multinomial === "multinomial") + assert(Bernoulli === "bernoulli") + } + test("get, set params") { val nb = new NaiveBayes() nb.setLambda(2.0) @@ -134,16 +138,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 42, "Multinomial") + val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + val model = NaiveBayes.train(testRDD, 1.0, Multinomial) validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 17, "Multinomial") + pi, theta, nPoints, 17, Multinomial) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -159,19 +162,19 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { val theta = Array( Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 - Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 45, "Bernoulli") + pi, theta, nPoints, 45, Bernoulli) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + val model = NaiveBayes.train(testRDD, 1.0, Bernoulli) validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 20, "Bernoulli") + pi, theta, nPoints, 20, Bernoulli) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -216,7 +219,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0))) intercept[SparkException] { - NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli") + NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli) } val okTrain = Seq( @@ -235,7 +238,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(1.0), Vectors.dense(0.0)) - val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli") + val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli) intercept[SparkException] { model.predict(sc.makeRDD(badPredict, 2)).collect() } @@ -275,14 +278,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) - assert(model.modelType === "Multinomial") + assert(model.modelType === Multinomial) } finally { Utils.deleteRecursively(tempDir) } } } -class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { +class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 6de098b383ba..b1d78cba9e3d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -21,9 +21,8 @@ import scala.collection.JavaConversions._ import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -46,7 +45,7 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => @@ -62,7 +61,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with MLlibTestSparkContext { +class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -91,7 +90,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -117,7 +116,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -127,8 +126,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -145,7 +144,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val initialB = -1.0 val initialC = -1.0 @@ -159,8 +158,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD, initialWeights) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -177,7 +176,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) val testRDDInvalid = testRDD.map { lp => @@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { } } -class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { +class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 5683b55e8500..fd653296c9d9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 @@ -159,4 +158,21 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert(error.head > 0.8 & error.last < 0.2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + val numBatches = 10 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index f356ffa3e3a2..b218d72f1268 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), @@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } } - + test("two clusters") { val data = sc.parallelize(GaussianTestData.data) @@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0f2b26d462ad..0dbbd7127444 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class KMeansSuite extends FunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} @@ -75,7 +74,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { val center = Vectors.dense(1.0, 2.0, 3.0) // Make sure code runs. - var model = KMeans.train(data, k=2, maxIterations=1) + var model = KMeans.train(data, k = 2, maxIterations = 1) assert(model.clusterCenters.size === 2) } @@ -87,7 +86,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { 2) // Make sure code runs. - var model = KMeans.train(data, k=3, maxIterations=1) + var model = KMeans.train(data, k = 3, maxIterations = 1) assert(model.clusterCenters.size === 3) } @@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { } } -object KMeansSuite extends FunSuite { +object KMeansSuite extends SparkFunSuite { def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { val singlePoint = isSparse match { case true => @@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite { } } -class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { +class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d5b7d9633574..406affa25539 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class LDASuite extends FunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6d6fe6fe46ba..19e65f1b53ab 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { +class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.PowerIterationClustering._ @@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) - + val model2 = new PowerIterationClustering() .setK(2) .setInitializationMode("degree") @@ -94,11 +92,13 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext */ val similarities = Seq[(Long, Long, Double)]( (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + // scalastyle:off val expected = Array( Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + // scalastyle:on val w = normalize(sc.parallelize(similarities, 2)) w.edges.collect().foreach { case Edge(i, j, x) => assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) @@ -128,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext } } -object PowerIterationClusteringSuite extends FunSuite { +object PowerIterationClusteringSuite extends SparkFunSuite { def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { val assignments = sc.parallelize( (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index f90025d535e4..ac01622b8a08 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends FunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 @@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { assert(math.abs(c1) ~== 0.8 absTol 0.6) } + test("SPARK-7946 setDecayFactor") { + val kMeans = new StreamingKMeans() + assert(kMeans.decayFactor === 1.0) + kMeans.setDecayFactor(2.0) + assert(kMeans.decayFactor === 2.0) + } + def StreamingKMeansDataGenerator( numPoints: Int, numBatches: Int, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 79847633ff0d..87ccc7eda44e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { +class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index e0224f960cc4..99d52fabc530 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { +class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 7dc4f3cfbc4e..d55bc8c3ec09 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 2537dd62c92f..f3b19aeb42f8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 609eed983ff4..c0924a213a84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 670b4c34e609..9de2bdb6d724 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( - Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, "explained variance regression score mismatch") @@ -39,7 +38,7 @@ class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( - Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, "explained variance regression score mismatch") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 747f5914598e..889727fb5582 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { /* * Contingency tables diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala index f3a482abda87..ccbf8a91cdd3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext { +class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { test("elementwise (hadamard) product should properly apply vector to dense data set") { val denseData = Array( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index 0c4dfb7b97c7..cf279c02334e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 0a5cad7caf8e..21163633051e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 5c4af2b99e68..34122d6ed2e9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - import breeze.linalg.{norm => brzNorm} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index 758af588f1c6..e57f49191378 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext -class PCASuite extends FunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { private val data = Array( Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 7f94564b2a3a..6ab2fa677012 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { // When the input data is all constant, the variance is zero. The standardization against // zero variance is not well-defined, but we decide to just set it into zero here. @@ -360,7 +359,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } withClue("model needs std and mean vectors to be equal size when both are provided") { intercept[IllegalArgumentException] { - val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 98a98a7599bc..b6818369208d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index bd5b9cc3afa1..66ae3543ecc4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.mllib.fpm -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { test("FP-Growth using String type") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala index 04017f67c311..a56d7b357921 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm import scala.language.existentials -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPTreeSuite extends FunSuite with MLlibTestSparkContext { +class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") { val tree = new FPTree[String] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index 699f009f0f2e..d34888af2d73 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.mllib.impl -import org.scalatest.FunSuite - import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { import PeriodicGraphCheckpointerSuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 002cb253862b..b0f3f71113c5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ -class BLASSuite extends FunSuite { +class BLASSuite extends SparkFunSuite { test("copy") { val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) @@ -140,7 +139,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dA) assert(dA ~== expected absTol 1e-15) - + val dB = new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0)) @@ -149,7 +148,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dB) } } - + val dC = new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8)) @@ -158,7 +157,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dC) } } - + val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5)) withClue("Size of vector must match the rank of matrix") { @@ -257,32 +256,96 @@ class BLASSuite extends FunSuite { new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) - val x = new DenseVector(Array(1.0, 2.0, 3.0)) + val dA2 = + new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true) + val sA2 = + new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0), + true) + + val dx = new DenseVector(Array(1.0, 2.0, 3.0)) + val sx = dx.toSparse val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA.multiply(x) ~== expected absTol 1e-15) - assert(sA.multiply(x) ~== expected absTol 1e-15) + assert(dA.multiply(dx) ~== expected absTol 1e-15) + assert(sA.multiply(dx) ~== expected absTol 1e-15) + assert(dA.multiply(sx) ~== expected absTol 1e-15) + assert(sA.multiply(sx) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy + val y5 = y1.copy + val y6 = y1.copy + val y7 = y1.copy + val y8 = y1.copy + val y9 = y1.copy + val y10 = y1.copy + val y11 = y1.copy + val y12 = y1.copy + val y13 = y1.copy + val y14 = y1.copy + val y15 = y1.copy + val y16 = y1.copy + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) - gemv(1.0, dA, x, 2.0, y1) - gemv(1.0, sA, x, 2.0, y2) - gemv(2.0, dA, x, 2.0, y3) - gemv(2.0, sA, x, 2.0, y4) + gemv(1.0, dA, dx, 2.0, y1) + gemv(1.0, sA, dx, 2.0, y2) + gemv(1.0, dA, sx, 2.0, y3) + gemv(1.0, sA, sx, 2.0, y4) + + gemv(1.0, dA2, dx, 2.0, y5) + gemv(1.0, sA2, dx, 2.0, y6) + gemv(1.0, dA2, sx, 2.0, y7) + gemv(1.0, sA2, sx, 2.0, y8) + + gemv(2.0, dA, dx, 2.0, y9) + gemv(2.0, sA, dx, 2.0, y10) + gemv(2.0, dA, sx, 2.0, y11) + gemv(2.0, sA, sx, 2.0, y12) + + gemv(2.0, dA2, dx, 2.0, y13) + gemv(2.0, sA2, dx, 2.0, y14) + gemv(2.0, dA2, sx, 2.0, y15) + gemv(2.0, sA2, sx, 2.0, y16) + assert(y1 ~== expected2 absTol 1e-15) assert(y2 ~== expected2 absTol 1e-15) - assert(y3 ~== expected3 absTol 1e-15) - assert(y4 ~== expected3 absTol 1e-15) + assert(y3 ~== expected2 absTol 1e-15) + assert(y4 ~== expected2 absTol 1e-15) + + assert(y5 ~== expected2 absTol 1e-15) + assert(y6 ~== expected2 absTol 1e-15) + assert(y7 ~== expected2 absTol 1e-15) + assert(y8 ~== expected2 absTol 1e-15) + + assert(y9 ~== expected3 absTol 1e-15) + assert(y10 ~== expected3 absTol 1e-15) + assert(y11 ~== expected3 absTol 1e-15) + assert(y12 ~== expected3 absTol 1e-15) + + assert(y13 ~== expected3 absTol 1e-15) + assert(y14 ~== expected3 absTol 1e-15) + assert(y15 ~== expected3 absTol 1e-15) + assert(y16 ~== expected3 absTol 1e-15) + withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(1.0, dA.transpose, x, 2.0, y1) + gemv(1.0, dA.transpose, dx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, sA.transpose, dx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, dA.transpose, sx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, sA.transpose, sx, 2.0, y1) } } + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = @@ -291,7 +354,9 @@ class BLASSuite extends FunSuite { val dATT = dAT.transpose val sATT = sAT.transpose - assert(dATT.multiply(x) ~== expected absTol 1e-15) - assert(sATT.multiply(x) ~== expected absTol 1e-15) + assert(dATT.multiply(dx) ~== expected absTol 1e-15) + assert(sATT.multiply(dx) ~== expected absTol 1e-15) + assert(dATT.multiply(sx) ~== expected absTol 1e-15) + assert(sATT.multiply(sx) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 203103237397..dc04258e41d2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} -class BreezeMatrixConversionSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index 8abdac72902c..3772c9235ad3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.SparkFunSuite + /** * Test Breeze vector conversions. */ -class BreezeVectorConversionSuite extends FunSuite { +class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 86119ec38101..a270ba2562db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg import java.util.Random import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class MatricesSuite extends FunSuite { +class MatricesSuite extends SparkFunSuite { test("dense matrix construction") { val m = 3 val n = 2 @@ -455,4 +455,14 @@ class MatricesSuite extends FunSuite { lines = mat.toString(5, 100).lines.toArray assert(lines.size == 5 && lines.forall(_.size <= 100)) } + + test("numNonzeros and numActives") { + val dm1 = Matrices.dense(3, 2, Array(0, 0, -1, 1, 0, 1)) + assert(dm1.numNonzeros === 3) + assert(dm1.numActives === 6) + + val sm1 = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)) + assert(sm1.numNonzeros === 1) + assert(sm1.numActives === 3) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 24755e9ff46f..c4ae0a16f7c0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends FunSuite { +class VectorsSuite extends SparkFunSuite { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 @@ -215,13 +214,13 @@ class VectorsSuite extends FunSuite { val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) - // SparseVector vs. SparseVector - assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) + // SparseVector vs. SparseVector + assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. SparseVector assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. DenseVector assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8) - } + } } test("foreachActive") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 949d1c993957..93fe04c139b9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { +class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 @@ -57,11 +56,13 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { val random = new ju.Random() // This should generate a 4x4 grid of 1x2 blocks. val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + // scalastyle:off val expected0 = Array( Array(0, 0, 4, 4, 8, 8, 12), Array(1, 1, 5, 5, 9, 9, 13), Array(2, 2, 6, 6, 10, 10, 14), Array(3, 3, 7, 7, 11, 11, 15)) + // scalastyle:on for (i <- 0 until 4; j <- 0 until 7) { assert(part0.getPartition((i, j)) === expected0(i)(j)) assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 04b36a9ef999..f3728cd036a3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { +class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 2ab53cc13db7..0ecb7a221a50 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -136,6 +135,17 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate matrix sizes of svd") { + val k = 2 + val A = new IndexedRowMatrix(indexedRows) + val svd = A.computeSVD(k, computeU = true) + assert(svd.U.numRows() === m) + assert(svd.U.numCols() === k) + assert(svd.s.size === k) + assert(svd.V.numRows === n) + assert(svd.V.numCols === k) + } + test("validate k in svd") { val A = new IndexedRowMatrix(indexedRows) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 27bb19f472e1..b6cb53d0c743 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } -class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext { +class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { var mat: RowMatrix = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 86481c6e6620..13b754a03943 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.mllib.optimization import scala.collection.JavaConversions._ import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -42,7 +43,7 @@ object GradientDescentSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -61,7 +62,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 @@ -81,11 +82,11 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Vectors.dense(1.0 +: features.toArray) + label -> MLUtils.appendBias(features) } val dataRDD = sc.parallelize(data, 2).cache() - val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) + val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0) val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, @@ -138,9 +139,48 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc "The different between newWeights with/without regularization " + "should be initialWeightsWithIntercept.") } + + test("iteration should end with convergence tolerance") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val initialB = -1.0 + val initialWeights = Array(initialB) + + val gradient = new LogisticGradient() + val updater = new SimpleUpdater() + val stepSize = 1.0 + val numIterations = 10 + val regParam = 0 + val miniBatchFrac = 1.0 + val convergenceTolerance = 5.0e-1 + + // Add a extra variable consisting of all 1.0's for the intercept. + val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) + val data = testData.map { case LabeledPoint(label, features) => + label -> MLUtils.appendBias(features) + } + + val dataRDD = sc.parallelize(data, 2).cache() + val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0) + + val (_, loss) = GradientDescent.runMiniBatchSGD( + dataRDD, + gradient, + updater, + stepSize, + numIterations, + regParam, + miniBatchFrac, + initialWeightsWithIntercept, + convergenceTolerance) + + assert(loss.length < numIterations, "convergenceTolerance failed to stop optimization early") + } } -class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext { +class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index c8f2adcf155a..75ae0eb32fb7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 @@ -121,7 +122,8 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { numGDIterations, regParam, miniBatchFrac, - initialWeightsWithIntercept) + initialWeightsWithIntercept, + convergenceTol) assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5, "The first losses of LBFGS and GD should be the same.") @@ -220,7 +222,8 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { numGDIterations, regParam, miniBatchFrac, - initialWeightsWithIntercept) + initialWeightsWithIntercept, + convergenceTol) // for class LBFGS and the optimize method, we only look at the weights assert( @@ -229,7 +232,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { } } -class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { +class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index 22855e4e8f24..d8f9b8c33963 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.FunSuite - import org.jblas.{DoubleMatrix, SimpleBlas} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class NNLSSuite extends FunSuite { +class NNLSSuite extends SparkFunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) @@ -68,12 +67,14 @@ class NNLSSuite extends FunSuite { test("NNLS: nonnegativity constraint active") { val n = 5 + // scalastyle:off val ata = new DoubleMatrix(Array( Array( 4.377, -3.531, -1.306, -0.139, 3.418), Array(-3.531, 4.344, 0.934, 0.305, -2.140), Array(-1.306, 0.934, 2.644, -0.203, -0.170), Array(-0.139, 0.305, -0.203, 5.883, 1.428), Array( 3.418, -2.140, -0.170, 1.428, 4.684))) + // scalastyle:on val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala index 0b646cf1ce6c..4c6e76e47419 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.dmg.pmml.RegressionNormalizationMethodType -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.util.LinearDataGenerator -class BinaryClassificationPMMLModelExportSuite extends FunSuite { +class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { test("logistic regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) @@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure logistic regression has normalization method set to LOGIT assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) } - + test("linear SVM PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - + // assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) val pmml = svmModelExport.getPmml @@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure linear SVM has normalization method set to NONE assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index f9afbd888dfc..1d3230948178 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class GeneralizedLinearPMMLModelExportSuite extends FunSuite { +class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite { test("linear regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index b985d0446d7b..b3f9750afa73 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors -class KMeansPMMLModelExportSuite extends FunSuite { +class KMeansPMMLModelExportSuite extends SparkFunSuite { test("KMeansPMMLModelExport generate PMML format") { val clusterCenters = Array( @@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite { val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index f28a4ac8ad01..af4945096175 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.pmml.export -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class PMMLModelExportFactorySuite extends FunSuite { +class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { val clusterCenters = Array( @@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite { test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + "when passing a LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) - + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) } - + test("PMMLModelExportFactory throw IllegalArgumentException " + "when passing a Multinomial Logistic Regression") { /** 3 classes, 2 features */ val multiclassLogisticRegressionModel = new LogisticRegressionModel( - weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) - + intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index b792d819fdab..a5ca1518f82f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.random import scala.math -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class RandomDataGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends SparkFunSuite { def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 63f2ea916d45..413db2000d6d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} @@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { +class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 57216e8eb4a5..10f5a2be48f7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ -class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5)), 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 6d6c0aa5be81..bc6417261483 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index b3798940ddc3..05b87728d6fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.math.abs import scala.util.Random -import org.scalatest.FunSuite import org.jblas.DoubleMatrix +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel @@ -84,7 +84,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with MLlibTestSparkContext { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c92866f3893..2c8ed057a516 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.recommendation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { +class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext { val rank = 2 var userFeatures: RDD[(Int, Array[Double])] = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 3b38bdf5ef5e..ea4f2865757c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.regression -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { private def round(d: Double) = { math.round(d * 100).toDouble / 100 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index 110c44a7193f..f8d0af8820e6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -class LabeledPointSuite extends FunSuite { +class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( @@ -32,6 +31,11 @@ class LabeledPointSuite extends FunSuite { } } + test("parse labeled points with whitespaces") { + val point = LabeledPoint.parse("(0.0, [1.0, 2.0])") + assert(point === LabeledPoint(0.0, Vectors.dense(1.0, 2.0))) + } + test("parse labeled points with v0.9 format") { val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index c9f5dc069ef2..39537e7bb4c7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LassoSuite { val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LassoSuite extends FunSuite with MLlibTestSparkContext { +class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -67,11 +66,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -100,7 +100,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { val testRDD = sc.parallelize(testData, 2).cache() val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40).setConvergenceTol(0.0005) val model = ls.run(testRDD, initialWeights) val weight0 = model.weights(0) @@ -110,11 +110,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData,2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -141,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { } } -class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { +class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 3781931c2f81..f88a1c33c9f7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LinearRegressionSuite { val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index d6c93cc0e49c..7a781fee634c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -33,7 +33,7 @@ private object RidgeRegressionSuite { val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { +class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => @@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 26604dbe6c1e..a2a4c5f6b8b7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 @@ -54,6 +53,7 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.2) .setNumIterations(25) + .setConvergenceTol(0.0001) // generate sequence of simulated data val numBatches = 10 @@ -167,4 +167,22 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert((error.head - error.last) > 2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + val numBatches = 10 + val nPoints = 100 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index d20a09b4b492..c292ced75e87 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) @@ -96,11 +95,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { val X = sc.parallelize(data) val defaultMat = Statistics.corr(X) val pearsonMat = Statistics.corr(X, "pearson") + // scalastyle:off val expected = BDM( (1.00000000, 0.05564149, Double.NaN, 0.4004714), (0.05564149, 1.00000000, Double.NaN, 0.9135959), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), - (0.40047142, 0.91359586, Double.NaN,1.0000000)) + (0.40047142, 0.91359586, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(defaultMat.toBreeze, expected)) assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) } @@ -108,11 +109,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { test("corr(X) spearman") { val X = sc.parallelize(data) val spearmanMat = Statistics.corr(X, "spearman") + // scalastyle:off val expected = BDM( (1.0000000, 0.1054093, Double.NaN, 0.4000000), (0.1054093, 1.0000000, Double.NaN, 0.9486833), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), (0.4000000, 0.9486833, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 15418e603596..b084a5fb4313 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat import java.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { +class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala index 16ecae23dd9d..5feccdf33681 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -17,31 +17,32 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import org.apache.commons.math3.distribution.NormalDistribution +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { +class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { val rdd = sc.parallelize(Array(5.0)) val evaluationPoints = Array(5.0, 6.0) - val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal = new NormalDistribution(5.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - normal.density(5.0) < acceptableErr) - assert(densities(0) - normal.density(6.0) < acceptableErr) + assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr) + assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr) } test("kernel density multiple samples") { val rdd = sc.parallelize(Array(5.0, 10.0)) val evaluationPoints = Array(5.0, 6.0) - val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal1 = new NormalDistribution(5.0, 3.0) val normal2 = new NormalDistribution(10.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr) - assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr) + assert(math.abs( + densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr) + assert(math.abs( + densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 23b0eec865de..07efde4f5e6d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateOnlineSummarizerSuite extends FunSuite { +class MultivariateOnlineSummarizerSuite extends SparkFunSuite { test("basic error handing") { val summarizer = new MultivariateOnlineSummarizer diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index fac2498e4dcb..aa60deb665ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,49 +17,48 @@ package org.apache.spark.mllib.stat.distribution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { +class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) - + val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } - + test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } - + test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index ce983eb27fa3..356d957f1590 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils -class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// // Tests examining individual elements of training @@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } } -object DecisionTreeSuite extends FunSuite { +object DecisionTreeSuite extends SparkFunSuite { def validateClassifier( model: DecisionTreeModel, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 55b0bac7d49f..84dd3b342d4c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GradientBoostedTrees]]. */ -class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 92b498580af0..49aff21fe791 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends FunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index ee3bc9848686..e6df5d974bf3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -35,7 +34,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[RandomForest]]. */ -class RandomForestSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) @@ -196,7 +195,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - EnsembleTestHelper.validateClassifier(model, arr, 1.0) } test("subsampling rate in RandomForest"){ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index b184e936672c..9d756da41032 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree.impl -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[BaggedPoint]]. */ -class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 668fc1d43c5d..70219e9ad9d3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -21,19 +21,19 @@ import java.io.File import scala.io.Source -import org.scalatest.FunSuite - import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { +class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") @@ -63,7 +63,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") - if (m > 10) { + if (m > 10) { val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) val norm4 = Vectors.norm(v4, 2.0) @@ -109,6 +109,40 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") { + val lines = + """ + |0 + |0 0:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + + test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") { + val lines = + """ + |0 + |0 3:4.0 2:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + test("saveAsLibSVMFile") { val examples = sc.parallelize(Seq( LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), @@ -168,7 +202,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { "Each training+validation set combined should contain all of the data.") } // K fold cross validation should only have each element in the validation set exactly once - assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === + assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted === data.collect().sorted) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index b658889476d3..5d1796ef6572 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.util -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SQLContext trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ + @transient var sqlContext: SQLContext = _ override def beforeAll() { super.beforeAll() @@ -31,12 +32,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => .setMaster("local[2]") .setAppName("MLlibUnitTest") sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) } override def afterAll() { + sqlContext = null if (sc != null) { sc.stop() } + sc = null super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index f68fb95eac4e..fa4f74d71b7e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.util -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.SparkException - -class NumericParserSuite extends FunSuite { +class NumericParserSuite extends SparkFunSuite { test("parser") { val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" @@ -39,4 +37,11 @@ class NumericParserSuite extends FunSuite { } } } + + test("parser with whitespaces") { + val s = "(0.0, [1.0, 2.0])" + val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] + assert(parsed(0).asInstanceOf[Double] === 0.0) + assert(parsed(1).asInstanceOf[Array[Double]] === Array(1.0, 2.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 59e6c778806f..8f475f30249d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.util +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.scalatest.exceptions.TestFailedException -class TestingUtilsSuite extends FunSuite { +class TestingUtilsSuite extends SparkFunSuite { test("Comparing doubles using relative error.") { diff --git a/network/common/pom.xml b/network/common/pom.xml index 0c3147761cfc..7dc3068ab8cb 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -77,7 +77,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 6b514aaa1290..7d27439cfde7 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -39,6 +39,12 @@ public class JavaUtils { private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + /** + * Define a default value for driver memory here since this value is referenced across the code + * base and nearly all files already use Utils.scala + */ + public static final long DEFAULT_DRIVER_MEM_MB = 1024; + /** Closes the given object, ignoring IOExceptions. */ public static void closeQuietly(Closeable closeable) { try { diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 7dc7c65825e3..532463e96fbb 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -79,7 +79,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index dd08e24cade2..022ed88a1648 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -108,7 +108,8 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) + || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } else { throw new UnsupportedOperationException( diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java index 60485bace643..ce954b8a289e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -24,6 +24,9 @@ import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** Request to read a set of blocks. Returns {@link StreamHandle}. */ public class OpenBlocks extends BlockTransferMessage { public final String appId; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index 38acae3b31d6..cca8b17c4f12 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -22,6 +22,9 @@ import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * Initial registration message between an executor and its local shuffle server. * Returns nothing (empty bye array). diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index 9a9220211a50..1915295aa6cc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -20,6 +20,9 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java index 2ff9aaa650f9..3caed59d508f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -24,6 +24,9 @@ import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ public class UploadBlock extends BlockTransferMessage { diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 1e2e9c80af6c..a99f7c4392d3 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index 86aa0a9fa134..ffa96128a3d6 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -102,6 +102,7 @@ external/twitter external/flume external/flume-sink + external/flume-assembly external/mqtt external/zeromq examples @@ -114,11 +115,10 @@ UTF-8 UTF-8 - org.spark-project.akka - 2.3.4-spark - 1.6 + com.typesafe.akka + 2.3.11 + 1.7 spark - 2.0.1 0.21.1 shaded-protobuf 1.7.10 @@ -130,14 +130,15 @@ hbase 1.4.0 3.4.5 + 2.4.0 org.spark-project.hive 0.13.1a 0.13.1 10.10.1.1 - 1.6.0rc3 - 1.2.3 + 1.7.0 + 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 0.5.0 @@ -147,8 +148,8 @@ 1.7.7 hadoop2 0.7.1 - 1.8.3 - 1.1.0 + 1.9.16 + 1.2.1 4.3.2 3.4.1 ${project.build.directory}/spark-test-classpath.txt @@ -156,7 +157,6 @@ 2.10 ${scala.version} org.scala-lang - 3.6.3 1.9.13 2.4.4 1.1.1.7 @@ -179,7 +179,7 @@ compile ${session.executionRootDirectory} @@ -249,7 +249,7 @@ mapr-repo MapR Repository - http://repository.mapr.com/maven + http://repository.mapr.com/maven/ true @@ -268,6 +268,18 @@ false + + + spark-1.4-staging + Spark 1.4 RC4 Staging Repository + https://repository.apache.org/content/repositories/orgapachespark-1112 + + true + + + false + + @@ -494,7 +506,7 @@ net.jpountz.lz4 lz4 - 1.2.0 + 1.3.0 com.clearspring.analytics @@ -575,7 +587,7 @@ io.netty netty-all - 4.0.23.Final + 4.0.28.Final org.apache.derby @@ -670,7 +682,7 @@ org.mockito - mockito-all + mockito-core 1.9.5 test @@ -707,7 +719,7 @@ org.apache.curator curator-recipes - 2.4.0 + ${curator.version} ${hadoop.deps.scope} @@ -716,6 +728,16 @@ + + org.apache.curator + curator-client + ${curator.version} + + + org.apache.curator + curator-framework + ${curator.version} + org.apache.hadoop hadoop-client @@ -726,6 +748,10 @@ asm asm + + org.codehaus.jackson + jackson-mapper-asl + org.ow2.asm asm @@ -738,6 +764,10 @@ commons-logging commons-logging + + org.mockito + mockito-all + org.mortbay.jetty servlet-api-2.5 @@ -1058,13 +1088,13 @@ - com.twitter + org.apache.parquet parquet-column ${parquet.version} ${parquet.deps.scope} - com.twitter + org.apache.parquet parquet-hadoop ${parquet.version} ${parquet.deps.scope} @@ -1194,15 +1224,6 @@ -target ${java.version} - - - - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} - - @@ -1231,7 +1252,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m org.apache.maven.plugins @@ -1648,11 +1693,13 @@ hadoop-1 - 1.0.4 + 1.2.1 2.4.1 0.98.7-hadoop1 hadoop1 1.8.8 + org.spark-project.akka + 2.3.4-spark @@ -1679,6 +1726,17 @@ + + hadoop-2.6 + + 2.6.0 + 0.9.3 + 3.1.1 + 3.4.6 + 2.6.0 + + + yarn @@ -1709,7 +1767,7 @@ org.apache.curator curator-recipes - 2.4.0 + ${curator.version} org.apache.zookeeper @@ -1731,22 +1789,6 @@ sql/hive-thriftserver - - hive-0.12.0 - - 0.12.0-protobuf-2.5 - 0.12.0 - 10.4.2.0 - - - - hive-0.13.1 - - 0.13.1a - 0.13.1 - 10.10.1.1 - - scala-2.10 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index dde92949fa17..f16bf989f200 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.3.0" + val previousSparkVersion = "1.4.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 487062a31f77..680b699e9e4a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,10 +34,60 @@ import com.typesafe.tools.mima.core.ProblemFilters._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.5") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // JavaRDDLike is not meant to be extended by user programs + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), + // Mima false positive (was a private[spark] class) + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.util.collection.PairIterator"), + // Removing a testing method from a private class + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), + // While private MiMa is still not happy about the changes, + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), + // SQL execution is considered private. + excludePackage("org.apache.spark.sql.execution"), + // NanoTime and CatalystTimestampConverter is only used inside catalyst, + // not needed anymore + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.timestamp.NanoTime"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.timestamp.NanoTime$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.CatalystTimestampConverter"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.CatalystTimestampConverter$"), + // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo$") + ) ++ Seq( + // SPARK-8479 Add numNonzeros and numActives to Matrix. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.numActives") + ) case v if v.startsWith("1.4") => Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), // These are needed if checking against the sbt build, since they are part of @@ -87,7 +137,14 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Vector.toSparse"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numActives") + "org.apache.spark.mllib.linalg.Vector.numActives"), + // SPARK-7681 add SparseVector support for gemv + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.multiply") ) ++ Seq( // Execution should never be included as its always internal. MimaBuild.excludeSparkPackage("sql.execution"), @@ -126,7 +183,10 @@ object MimaExcludes { "org.apache.spark.sql.parquet.TestGroupWriteSupport"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), + // TODO: Remove the following rule once ParquetTest has been moved to src/test. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTest") ) ++ Seq( // SPARK-7530 Added StreamingContext.getState() ProblemFilters.exclude[MissingMethodProblem]( @@ -137,6 +197,14 @@ object MimaExcludes { // implementing this interface in Java. Note that ShuffleWriter is private[spark]. ProblemFilters.exclude[IncompatibleTemplateDefProblem]( "org.apache.spark.shuffle.ShuffleWriter") + ) ++ Seq( + // SPARK-6888 make jdbc driver handling user definable + // This patch renames some classes to API friendly names. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1b87e4e98bd8..3408c6d51ed4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,11 +23,12 @@ import scala.collection.JavaConversions._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ -import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings +import spray.revolver.RevolverPlugin._ + object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile @@ -44,14 +45,16 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation + + val testTempDir = s"$sparkHome/target/tmp" } object SparkBuild extends PomBuild { @@ -118,7 +121,12 @@ object SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( + lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq( + libraryDependencies += compilerPlugin( + "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), + scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) + + lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( javaHome := sys.env.get("JAVA_HOME") .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), @@ -126,7 +134,7 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - unidocGenjavadocVersion := "0.8", + unidocGenjavadocVersion := "0.9-spark0", resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), @@ -140,7 +148,9 @@ object SparkBuild extends PomBuild { javacOptions in (Compile, doc) ++= { val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3) if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty - } + }, + + javacOptions in Compile ++= Seq("-encoding", "UTF-8") ) def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { @@ -151,14 +161,13 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExludedDependencies.settings)) + .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: remove launcher from this list after 1.4.0 allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } @@ -174,9 +183,6 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) - /* Catalyst macro settings */ - enable(Catalyst.settings)(catalyst) - /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -200,7 +206,7 @@ object SparkBuild extends PomBuild { fork := true, outputStrategy in run := Some (StdoutOutput), - javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=1g"), + javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m"), sparkShell := { (runMain in Compile).toTask(" org.apache.spark.repl.Main -usejavacp").value @@ -240,7 +246,7 @@ object Flume { This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. */ -object ExludedDependencies { +object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } ) @@ -271,14 +277,6 @@ object OldDeps { ) } -object Catalyst { - lazy val settings = Seq( - addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), - // Quasiquotes break compiling scala doc... - // TODO: Investigate fixing this. - sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen"))) -} - object SQL { lazy val settings = Seq( initialCommands in console := @@ -301,7 +299,7 @@ object SQL { object Hive { lazy val settings = Seq( - javaOptions += "-XX:MaxPermSize=1g", + javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), // Multiple queries rely on the TestHive singleton. See comments there for more details. @@ -324,6 +322,7 @@ object Hive { |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ + |import org.apache.spark.sql.hive.test.TestHive.implicits._ |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce @@ -348,7 +347,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { @@ -502,6 +501,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), + javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", @@ -510,10 +510,11 @@ object TestSettings { javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", + javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", - javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" + javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. @@ -523,6 +524,13 @@ object TestSettings { libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, + // Make sure the test temp directory exists. + resourceGenerators in Test <+= resourceManaged in Test map { outDir: File => + if (!new File(testTempDir).isDirectory()) { + require(new File(testTempDir).mkdirs()) + } + Seq[File]() + }, concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), // Remove certain packages from Scaladoc scalacOptions in (Compile, doc) := Seq( diff --git a/project/plugins.sbt b/project/plugins.sbt index 7096b0d3ee7d..51820460ca1a 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -25,10 +25,12 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") +addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") + libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 0d21a132048a..6ef8cf53cc74 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -261,3 +261,9 @@ def _start_update_server(): thread.daemon = True thread.start() return server + +if __name__ == "__main__": + import doctest + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 3de4615428bb..663c9abe0881 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -115,4 +115,6 @@ def __reduce__(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 31992795a9e4..d7466729b8f3 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -173,6 +173,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._jvm.PythonAccumulatorParam(host, port)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') + self.pythonVer = "%d.%d" % sys.version_info[:2] # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -290,6 +291,26 @@ def version(self): """ return self._jsc.version() + @property + @ignore_unicode_prefix + def applicationId(self): + """ + A unique identifier for the Spark application. + Its format depends on the scheduler implementation. + (i.e. + in case of local spark app something like 'local-1433865536131' + in case of YARN something like 'application_1433865536131_34483' + ) + >>> sc.applicationId # doctest: +ELLIPSIS + u'local-...' + """ + return self._jsc.sc().applicationId() + + @property + def startTime(self): + """Return the epoch time when the Spark Context was started.""" + return self._jsc.startTime() + @property def defaultParallelism(self): """ @@ -318,6 +339,38 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None + def emptyRDD(self): + """ + Create an RDD that has no partitions or elements. + """ + return RDD(self._jsc.emptyRDD(), self, NoOpSerializer()) + + def range(self, start, end=None, step=1, numSlices=None): + """ + Create a new RDD of int containing elements from `start` to `end` + (exclusive), increased by `step` every element. Can be called the same + way as python's built-in range() function. If called with a single argument, + the argument is interpreted as `end`, and `start` is set to 0. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numSlices: the number of partitions of the new RDD + :return: An RDD of int + + >>> sc.range(5).collect() + [0, 1, 2, 3, 4] + >>> sc.range(2, 4).collect() + [2, 3] + >>> sc.range(1, 7, 2).collect() + [1, 3, 5] + """ + if end is None: + end = start + start = 0 + + return self.parallelize(xrange(start, end, step), numSlices) + def parallelize(self, c, numSlices=None): """ Distribute a local Python collection to form an RDD. Using xrange diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index 4ef2afe03544..b27e91a4cc25 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -883,6 +883,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": - import doctest - print(doctest.testmod()) + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3cee4ea6e3a3..90cd342a6cf7 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -51,6 +51,8 @@ def launch_gateway(): on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + if os.environ.get("SPARK_TESTING"): + submit_args = "--conf spark.ui.enabled=false " + submit_args command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index da793d9db7f9..327a11b14b5a 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -15,6 +15,6 @@ # limitations under the License. # -from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel, Evaluator +from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel -__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel", "Evaluator"] +__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 1411d3fd9c56..7abbde8b260e 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -55,7 +55,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ... TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.classification.LogisticRegression" + # a placeholder to make it appear in the generated doc elasticNetParam = \ Param(Params._dummy(), "elasticNetParam", @@ -75,6 +75,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred threshold=0.5, probabilityCol="probability") """ super(LogisticRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.LogisticRegression", self.uid) #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty # is an L2 penalty. For alpha = 1, it is an L1 penalty. self.elasticNetParam = \ @@ -111,7 +113,7 @@ def setElasticNetParam(self, value): """ Sets the value of :py:attr:`elasticNetParam`. """ - self.paramMap[self.elasticNetParam] = value + self._paramMap[self.elasticNetParam] = value return self def getElasticNetParam(self): @@ -124,7 +126,7 @@ def setFitIntercept(self, value): """ Sets the value of :py:attr:`fitIntercept`. """ - self.paramMap[self.fitIntercept] = value + self._paramMap[self.fitIntercept] = value return self def getFitIntercept(self): @@ -137,7 +139,7 @@ def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. """ - self.paramMap[self.threshold] = value + self._paramMap[self.threshold] = value return self def getThreshold(self): @@ -208,7 +210,6 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 1.0 """ - _java_class = "org.apache.spark.ml.classification.DecisionTreeClassifier" # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + @@ -224,6 +225,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") """ super(DecisionTreeClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid) #: param for Criterion used for information gain calculation (case-insensitive). self.impurity = \ Param(self, "impurity", @@ -256,7 +259,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self.paramMap[self.impurity] = value + self._paramMap[self.impurity] = value return self def getImpurity(self): @@ -289,7 +292,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) - >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed") + >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction @@ -299,7 +302,6 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 1.0 """ - _java_class = "org.apache.spark.ml.classification.RandomForestClassifier" # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + @@ -317,14 +319,16 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", - numTrees=20, featureSubsetStrategy="auto", seed=42): + numTrees=20, featureSubsetStrategy="auto", seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ - numTrees=20, featureSubsetStrategy="auto", seed=42) + numTrees=20, featureSubsetStrategy="auto", seed=None) """ super(RandomForestClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) #: param for Criterion used for information gain calculation (case-insensitive). self.impurity = \ Param(self, "impurity", @@ -343,7 +347,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "The number of features to consider for splits at each tree node. Supported " + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="gini", numTrees=20, featureSubsetStrategy="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -351,12 +355,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="gini", numTrees=20, featureSubsetStrategy="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ impurity="gini", numTrees=20, featureSubsetStrategy="auto") Sets params for linear classification. """ @@ -370,7 +374,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self.paramMap[self.impurity] = value + self._paramMap[self.impurity] = value return self def getImpurity(self): @@ -383,7 +387,7 @@ def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. """ - self.paramMap[self.subsamplingRate] = value + self._paramMap[self.subsamplingRate] = value return self def getSubsamplingRate(self): @@ -396,7 +400,7 @@ def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. """ - self.paramMap[self.numTrees] = value + self._paramMap[self.numTrees] = value return self def getNumTrees(self): @@ -409,7 +413,7 @@ def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. """ - self.paramMap[self.featureSubsetStrategy] = value + self._paramMap[self.featureSubsetStrategy] = value return self def getFeatureSubsetStrategy(self): @@ -452,7 +456,6 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol 1.0 """ - _java_class = "org.apache.spark.ml.classification.GBTClassifier" # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + @@ -476,6 +479,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred lossType="logistic", maxIter=20, stepSize=0.1) """ super(GBTClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.GBTClassifier", self.uid) #: param for Loss function which GBT tries to minimize (case-insensitive). self.lossType = Param(self, "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + @@ -517,7 +522,7 @@ def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. """ - self.paramMap[self.lossType] = value + self._paramMap[self.lossType] = value return self def getLossType(self): @@ -530,7 +535,7 @@ def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. """ - self.paramMap[self.subsamplingRate] = value + self._paramMap[self.subsamplingRate] = value return self def getSubsamplingRate(self): @@ -543,7 +548,7 @@ def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self.paramMap[self.stepSize] = value + self._paramMap[self.stepSize] = value return self def getStepSize(self): diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 02020ebff94c..595593a7f2cd 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,13 +15,72 @@ # limitations under the License. # -from pyspark.ml.wrapper import JavaEvaluator +from abc import abstractmethod, ABCMeta + +from pyspark.ml.wrapper import JavaWrapper from pyspark.ml.param import Param, Params -from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['BinaryClassificationEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] + + +@inherit_doc +class Evaluator(Params): + """ + Base class for evaluators that compute metrics from predictions. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _evaluate(self, dataset): + """ + Evaluates the output. + + :param dataset: a dataset that contains labels/observations and + predictions + :return: metric + """ + raise NotImplementedError() + + def evaluate(self, dataset, params={}): + """ + Evaluates the output with optional parameters. + + :param dataset: a dataset that contains labels/observations and + predictions + :param params: an optional param map that overrides embedded + params + :return: metric + """ + if isinstance(params, dict): + if params: + return self.copy(params)._evaluate(dataset) + else: + return self._evaluate(dataset) + else: + raise ValueError("Params must be a param map but got %s." % type(params)) + + +@inherit_doc +class JavaEvaluator(Evaluator, JavaWrapper): + """ + Base class for :py:class:`Evaluator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _evaluate(self, dataset): + """ + Evaluates the output. + :param dataset: a dataset that contains labels/observations and predictions. + :return: evaluation metric + """ + self._transfer_params_to_java() + return self._java_obj.evaluate(dataset._jdf) @inherit_doc @@ -42,8 +101,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction 0.83... """ - _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator" - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)") @@ -56,6 +113,8 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") """ super(BinaryClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) #: param for metric name in evaluation (areaUnderROC|areaUnderPR) self.metricName = Param(self, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)") @@ -68,7 +127,7 @@ def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. """ - self.paramMap[self.metricName] = value + self._paramMap[self.metricName] = value return self def getMetricName(self): @@ -89,6 +148,72 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", return self._set(**kwargs) +@inherit_doc +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Regression, which expects two input + columns: prediction and label. + + >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5), + ... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + ... + >>> evaluator = RegressionEvaluator(predictionCol="raw") + >>> evaluator.evaluate(dataset) + -2.842... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) + 0.993... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) + -2.649... + """ + # Because we will maximize evaluation value (ref: `CrossValidator`), + # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + # we take and output the negative of this metric. + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="rmse") + """ + super(RegressionEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) + #: param for metric name in evaluation (mse|rmse|r2|mae) + self.metricName = Param(self, "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="rmse") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="rmse") + Sets params for regression evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 58e22190c7c3..8804dace849b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder', +__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel'] @@ -43,7 +43,6 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): 1.0 """ - _java_class = "org.apache.spark.ml.feature.Binarizer" # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", "threshold in binary classification prediction, in range [0, 1]") @@ -54,6 +53,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): __init__(self, threshold=0.0, inputCol=None, outputCol=None) """ super(Binarizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) self.threshold = Param(self, "threshold", "threshold in binary classification prediction, in range [0, 1]") self._setDefault(threshold=0.0) @@ -73,7 +73,7 @@ def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. """ - self.paramMap[self.threshold] = value + self._paramMap[self.threshold] = value return self def getThreshold(self): @@ -104,7 +104,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): 0.0 """ - _java_class = "org.apache.spark.ml.feature.Bucketizer" # a placeholder to make it appear in the generated doc splits = \ Param(Params._dummy(), "splits", @@ -121,6 +120,7 @@ def __init__(self, splits=None, inputCol=None, outputCol=None): __init__(self, splits=None, inputCol=None, outputCol=None) """ super(Bucketizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) #: param for Splitting points for mapping continuous features into buckets. With n+1 splits, # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) # except the last bucket, which also includes y. The splits should be strictly increasing. @@ -150,7 +150,7 @@ def setSplits(self, value): """ Sets the value of :py:attr:`splits`. """ - self.paramMap[self.splits] = value + self._paramMap[self.splits] = value return self def getSplits(self): @@ -177,14 +177,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) """ - _java_class = "org.apache.spark.ml.feature.HashingTF" - @keyword_only def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) self._setDefault(numFeatures=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -217,8 +216,6 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([0.2877, 0.0]) """ - _java_class = "org.apache.spark.ml.feature.IDF" - # a placeholder to make it appear in the generated doc minDocFreq = Param(Params._dummy(), "minDocFreq", "minimum of documents in which a term should appear for filtering") @@ -229,6 +226,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): __init__(self, minDocFreq=0, inputCol=None, outputCol=None) """ super(IDF, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) self.minDocFreq = Param(self, "minDocFreq", "minimum of documents in which a term should appear for filtering") self._setDefault(minDocFreq=0) @@ -248,7 +246,7 @@ def setMinDocFreq(self, value): """ Sets the value of :py:attr:`minDocFreq`. """ - self.paramMap[self.minDocFreq] = value + self._paramMap[self.minDocFreq] = value return self def getMinDocFreq(self): @@ -257,6 +255,9 @@ def getMinDocFreq(self): """ return self.getOrDefault(self.minDocFreq) + def _create_model(self, java_model): + return IDFModel(java_model) + class IDFModel(JavaModel): """ @@ -264,6 +265,75 @@ class IDFModel(JavaModel): """ +@inherit_doc +@ignore_unicode_prefix +class NGram(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that converts the input array of strings into an array of n-grams. Null + values in the input array are ignored. + It returns an array of n-grams where each n-gram is represented by a space-separated string of + words. + When the input is empty, an empty array is returned. + When the input array length is less than n (number of elements per n-gram), no n-grams are + returned. + + >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) + >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) + >>> # Change n-gram length + >>> ngram.setParams(n=4).transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Temporarily modify output column. + >>> ngram.transform(df, {ngram.outputCol: "output"}).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e']) + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Must use keyword arguments to specify params. + >>> ngram.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. + """ + + # a placeholder to make it appear in the generated doc + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + + @keyword_only + def __init__(self, n=2, inputCol=None, outputCol=None): + """ + __init__(self, n=2, inputCol=None, outputCol=None) + """ + super(NGram, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) + self.n = Param(self, "n", "number of elements per n-gram (>=1)") + self._setDefault(n=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, n=2, inputCol=None, outputCol=None): + """ + setParams(self, n=2, inputCol=None, outputCol=None) + Sets params for this NGram. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setN(self, value): + """ + Sets the value of :py:attr:`n`. + """ + self._paramMap[self.n] = value + return self + + def getN(self): + """ + Gets the value of n or its default value. + """ + return self.getOrDefault(self.n) + + @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ @@ -285,14 +355,13 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): # a placeholder to make it appear in the generated doc p = Param(Params._dummy(), "p", "the p norm value.") - _java_class = "org.apache.spark.ml.feature.Normalizer" - @keyword_only def __init__(self, p=2.0, inputCol=None, outputCol=None): """ __init__(self, p=2.0, inputCol=None, outputCol=None) """ super(Normalizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) self.p = Param(self, "p", "the p norm value.") self._setDefault(p=2.0) kwargs = self.__init__._input_kwargs @@ -311,7 +380,7 @@ def setP(self, value): """ Sets the value of :py:attr:`p`. """ - self.paramMap[self.p] = value + self._paramMap[self.p] = value return self def getP(self): @@ -324,66 +393,73 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ - A one-hot encoder that maps a column of label indices to a column of binary vectors, with - at most a single one-value. By default, the binary vector has an element for each category, so - with 5 categories, an input value of 2.0 would map to an output vector of - (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so - the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - linearly dependent because they sum up to one. - - TODO: This method requires the use of StringIndexer first. Decouple them. + A one-hot encoder that maps a column of category indices to a + column of binary vectors, with at most a single one-value per row + that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to + an output vector of `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via + :py:attr:`dropLast`) because it makes the vector entries sum up to + one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + Note that this is different from scikit-learn's OneHotEncoder, + which keeps all categories. + The output vectors are sparse. + + .. seealso:: + + :py:class:`StringIndexer` for converting categorical values into + category indices >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) - >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features") + >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features") >>> encoder.transform(td).head().features - SparseVector(2, {}) + SparseVector(2, {0: 1.0}) >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs - SparseVector(2, {}) - >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"} + SparseVector(2, {0: 1.0}) + >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) """ - _java_class = "org.apache.spark.ml.feature.OneHotEncoder" - # a placeholder to make it appear in the generated doc - includeFirst = Param(Params._dummy(), "includeFirst", "include first category") + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") @keyword_only - def __init__(self, includeFirst=True, inputCol=None, outputCol=None): + def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ __init__(self, includeFirst=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() - self.includeFirst = Param(self, "includeFirst", "include first category") - self._setDefault(includeFirst=True) + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) + self.dropLast = Param(self, "dropLast", "whether to drop the last category") + self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, includeFirst=True, inputCol=None, outputCol=None): + def setParams(self, dropLast=True, inputCol=None, outputCol=None): """ - setParams(self, includeFirst=True, inputCol=None, outputCol=None) + setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def setIncludeFirst(self, value): + def setDropLast(self, value): """ - Sets the value of :py:attr:`includeFirst`. + Sets the value of :py:attr:`dropLast`. """ - self.paramMap[self.includeFirst] = value + self._paramMap[self.dropLast] = value return self - def getIncludeFirst(self): + def getDropLast(self): """ - Gets the value of includeFirst or its default value. + Gets the value of dropLast or its default value. """ - return self.getOrDefault(self.includeFirst) + return self.getOrDefault(self.dropLast) @inherit_doc @@ -404,8 +480,6 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) """ - _java_class = "org.apache.spark.ml.feature.PolynomialExpansion" - # a placeholder to make it appear in the generated doc degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)") @@ -415,6 +489,8 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): __init__(self, degree=2, inputCol=None, outputCol=None) """ super(PolynomialExpansion, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)") self._setDefault(degree=2) kwargs = self.__init__._input_kwargs @@ -433,7 +509,7 @@ def setDegree(self, value): """ Sets the value of :py:attr:`degree`. """ - self.paramMap[self.degree] = value + self._paramMap[self.degree] = value return self def getDegree(self): @@ -447,23 +523,25 @@ def getDegree(self): @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ - A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) - or using it to split the text (set matching to false). Optional parameters also allow filtering - tokens using a minimal length. + A regex based tokenizer that extracts tokens either by using the + provided regex pattern (in Java dialect) to split the text + (default) or repeatedly matching the regex (if gaps is true). + Optional parameters also allow filtering tokens using a minimal + length. It returns an array of strings that can be empty. - >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) + >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words") >>> reTokenizer.transform(df).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'a b c', words=[u'a', u'b', u'c']) >>> # Change a parameter. >>> reTokenizer.setParams(outputCol="tokens").transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Temporarily modify a parameter. >>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'a b c', words=[u'a', u'b', u'c']) >>> reTokenizer.transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Must use keyword arguments to specify params. >>> reTokenizer.setParams("text") Traceback (most recent call last): @@ -471,33 +549,29 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.feature.RegexTokenizer" # a placeholder to make it appear in the generated doc minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") - gaps = Param(Params._dummy(), "gaps", "Set regex to match gaps or tokens") - pattern = Param(Params._dummy(), "pattern", "regex pattern used for tokenizing") + gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") + pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") @keyword_only - def __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", - inputCol=None, outputCol=None): + def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): """ - __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", \ - inputCol=None, outputCol=None) + __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) """ super(RegexTokenizer, self).__init__() - self.minTokenLength = Param(self, "minLength", "minimum token length (>= 0)") - self.gaps = Param(self, "gaps", "Set regex to match gaps or tokens") - self.pattern = Param(self, "pattern", "regex pattern used for tokenizing") - self._setDefault(minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+") + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) + self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)") + self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens") + self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing") + self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", - inputCol=None, outputCol=None): + def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): """ - setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", \ - inputCol="input", outputCol="output") + setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) Sets params for this RegexTokenizer. """ kwargs = self.setParams._input_kwargs @@ -507,7 +581,7 @@ def setMinTokenLength(self, value): """ Sets the value of :py:attr:`minTokenLength`. """ - self.paramMap[self.minTokenLength] = value + self._paramMap[self.minTokenLength] = value return self def getMinTokenLength(self): @@ -520,7 +594,7 @@ def setGaps(self, value): """ Sets the value of :py:attr:`gaps`. """ - self.paramMap[self.gaps] = value + self._paramMap[self.gaps] = value return self def getGaps(self): @@ -533,7 +607,7 @@ def setPattern(self, value): """ Sets the value of :py:attr:`pattern`. """ - self.paramMap[self.pattern] = value + self._paramMap[self.pattern] = value return self def getPattern(self): @@ -557,8 +631,6 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([1.4142]) """ - _java_class = "org.apache.spark.ml.feature.StandardScaler" - # a placeholder to make it appear in the generated doc withMean = Param(Params._dummy(), "withMean", "Center data with mean") withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation") @@ -569,6 +641,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None) """ super(StandardScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) self.withMean = Param(self, "withMean", "Center data with mean") self.withStd = Param(self, "withStd", "Scale to unit standard deviation") self._setDefault(withMean=False, withStd=True) @@ -588,7 +661,7 @@ def setWithMean(self, value): """ Sets the value of :py:attr:`withMean`. """ - self.paramMap[self.withMean] = value + self._paramMap[self.withMean] = value return self def getWithMean(self): @@ -601,7 +674,7 @@ def setWithStd(self, value): """ Sets the value of :py:attr:`withStd`. """ - self.paramMap[self.withStd] = value + self._paramMap[self.withStd] = value return self def getWithStd(self): @@ -610,6 +683,9 @@ def getWithStd(self): """ return self.getOrDefault(self.withStd) + def _create_model(self, java_model): + return StandardScalerModel(java_model) + class StandardScalerModel(JavaModel): """ @@ -633,14 +709,13 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] """ - _java_class = "org.apache.spark.ml.feature.StringIndexer" - @keyword_only def __init__(self, inputCol=None, outputCol=None): """ __init__(self, inputCol=None, outputCol=None) """ super(StringIndexer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -653,6 +728,9 @@ def setParams(self, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + def _create_model(self, java_model): + return StringIndexerModel(java_model) + class StringIndexerModel(JavaModel): """ @@ -686,14 +764,13 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.feature.Tokenizer" - @keyword_only def __init__(self, inputCol=None, outputCol=None): """ __init__(self, inputCol=None, outputCol=None) """ super(Tokenizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -723,14 +800,13 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): DenseVector([0.0, 1.0]) """ - _java_class = "org.apache.spark.ml.feature.VectorAssembler" - @keyword_only def __init__(self, inputCols=None, outputCol=None): """ __init__(self, inputCols=None, outputCol=None) """ super(VectorAssembler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -797,7 +873,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([1.0, 0.0]) """ - _java_class = "org.apache.spark.ml.feature.VectorIndexer" # a placeholder to make it appear in the generated doc maxCategories = Param(Params._dummy(), "maxCategories", "Threshold for the number of values a categorical feature can take " + @@ -810,6 +885,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): __init__(self, maxCategories=20, inputCol=None, outputCol=None) """ super(VectorIndexer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) self.maxCategories = Param(self, "maxCategories", "Threshold for the number of values a categorical feature " + "can take (>= 2). If a feature is found to have " + @@ -831,7 +907,7 @@ def setMaxCategories(self, value): """ Sets the value of :py:attr:`maxCategories`. """ - self.paramMap[self.maxCategories] = value + self._paramMap[self.maxCategories] = value return self def getMaxCategories(self): @@ -840,6 +916,15 @@ def getMaxCategories(self): """ return self.getOrDefault(self.maxCategories) + def _create_model(self, java_model): + return VectorIndexerModel(java_model) + + +class VectorIndexerModel(JavaModel): + """ + Model fitted by VectorIndexer. + """ + @inherit_doc @ignore_unicode_prefix @@ -855,7 +940,6 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) """ - _java_class = "org.apache.spark.ml.feature.Word2Vec" # a placeholder to make it appear in the generated doc vectorSize = Param(Params._dummy(), "vectorSize", "the dimension of codes after transforming from words") @@ -867,12 +951,13 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has @keyword_only def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=42, inputCol=None, outputCol=None): + seed=None, inputCol=None, outputCol=None): """ __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \ - seed=42, inputCol=None, outputCol=None) + seed=None, inputCol=None, outputCol=None) """ super(Word2Vec, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) self.vectorSize = Param(self, "vectorSize", "the dimension of codes after transforming from words") self.numPartitions = Param(self, "numPartitions", @@ -881,15 +966,15 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, "the minimum number of times a token must appear to be included " + "in the word2vec model's vocabulary") self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=42) + seed=None) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=42, inputCol=None, outputCol=None): + seed=None, inputCol=None, outputCol=None): """ - setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42, \ + setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \ inputCol=None, outputCol=None) Sets params for this Word2Vec. """ @@ -900,7 +985,7 @@ def setVectorSize(self, value): """ Sets the value of :py:attr:`vectorSize`. """ - self.paramMap[self.vectorSize] = value + self._paramMap[self.vectorSize] = value return self def getVectorSize(self): @@ -913,7 +998,7 @@ def setNumPartitions(self, value): """ Sets the value of :py:attr:`numPartitions`. """ - self.paramMap[self.numPartitions] = value + self._paramMap[self.numPartitions] = value return self def getNumPartitions(self): @@ -926,7 +1011,7 @@ def setMinCount(self, value): """ Sets the value of :py:attr:`minCount`. """ - self.paramMap[self.minCount] = value + self._paramMap[self.minCount] = value return self def getMinCount(self): @@ -935,6 +1020,9 @@ def getMinCount(self): """ return self.getOrDefault(self.minCount) + def _create_model(self, java_model): + return Word2VecModel(java_model) + class Word2VecModel(JavaModel): """ diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 49c20b4cf70c..7845536161e0 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -16,6 +16,7 @@ # from abc import ABCMeta +import copy from pyspark.ml.util import Identifiable @@ -29,9 +30,9 @@ class Param(object): """ def __init__(self, parent, name, doc): - if not isinstance(parent, Params): - raise TypeError("Parent must be a Params but got type %s." % type(parent)) - self.parent = parent + if not isinstance(parent, Identifiable): + raise TypeError("Parent must be an Identifiable but got type %s." % type(parent)) + self.parent = parent.uid self.name = str(name) self.doc = str(doc) @@ -41,6 +42,15 @@ def __str__(self): def __repr__(self): return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + if isinstance(other, Param): + return self.parent == other.parent and self.name == other.name + else: + return False + class Params(Identifiable): """ @@ -51,10 +61,13 @@ class Params(Identifiable): __metaclass__ = ABCMeta #: internal param map for user-supplied values param map - paramMap = {} + _paramMap = {} #: internal param map for default values - defaultParamMap = {} + _defaultParamMap = {} + + #: value returned by :py:func:`params` + _params = None @property def params(self): @@ -63,10 +76,12 @@ def params(self): uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ - return list(filter(lambda attr: isinstance(attr, Param), - [getattr(self, x) for x in dir(self) if x != "params"])) + if self._params is None: + self._params = list(filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"])) + return self._params - def _explain(self, param): + def explainParam(self, param): """ Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string. @@ -74,10 +89,10 @@ def _explain(self, param): param = self._resolveParam(param) values = [] if self.isDefined(param): - if param in self.defaultParamMap: - values.append("default: %s" % self.defaultParamMap[param]) - if param in self.paramMap: - values.append("current: %s" % self.paramMap[param]) + if param in self._defaultParamMap: + values.append("default: %s" % self._defaultParamMap[param]) + if param in self._paramMap: + values.append("current: %s" % self._paramMap[param]) else: values.append("undefined") valueStr = "(" + ", ".join(values) + ")" @@ -88,7 +103,7 @@ def explainParams(self): Returns the documentation of all params with their optionally default values and user-supplied values. """ - return "\n".join([self._explain(param) for param in self.params]) + return "\n".join([self.explainParam(param) for param in self.params]) def getParam(self, paramName): """ @@ -105,56 +120,76 @@ def isSet(self, param): Checks whether a param is explicitly set by user. """ param = self._resolveParam(param) - return param in self.paramMap + return param in self._paramMap def hasDefault(self, param): """ Checks whether a param has a default value. """ param = self._resolveParam(param) - return param in self.defaultParamMap + return param in self._defaultParamMap def isDefined(self, param): """ - Checks whether a param is explicitly set by user or has a default value. + Checks whether a param is explicitly set by user or has + a default value. """ return self.isSet(param) or self.hasDefault(param) + def hasParam(self, paramName): + """ + Tests whether this instance contains a param with a given + (string) name. + """ + param = self._resolveParam(paramName) + return param in self.params + def getOrDefault(self, param): """ Gets the value of a param in the user-supplied param map or its - default value. Raises an error if either is set. + default value. Raises an error if neither is set. """ - if isinstance(param, Param): - if param in self.paramMap: - return self.paramMap[param] - else: - return self.defaultParamMap[param] - elif isinstance(param, str): - return self.getOrDefault(self.getParam(param)) + param = self._resolveParam(param) + if param in self._paramMap: + return self._paramMap[param] else: - raise KeyError("Cannot recognize %r as a param." % param) + return self._defaultParamMap[param] - def extractParamMap(self, extraParamMap={}): + def extractParamMap(self, extra={}): """ Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < - user-supplied values < extraParamMap. - :param extraParamMap: extra param values + user-supplied values < extra. + :param extra: extra param values :return: merged param map """ - paramMap = self.defaultParamMap.copy() - paramMap.update(self.paramMap) - paramMap.update(extraParamMap) + paramMap = self._defaultParamMap.copy() + paramMap.update(self._paramMap) + paramMap.update(extra) return paramMap + def copy(self, extra={}): + """ + Creates a copy of this instance with the same uid and some + extra params. The default implementation creates a + shallow copy using :py:func:`copy.copy`, and then copies the + embedded and extra parameters over and returns the copy. + Subclasses should override this method if the default approach + is not sufficient. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + that = copy.copy(self) + that._paramMap = self.extractParamMap(extra) + return that + def _shouldOwn(self, param): """ Validates that the input param belongs to this Params instance. """ - if param.parent is not self: + if not (self.uid == param.parent and self.hasParam(param.name)): raise ValueError("Param %r does not belong to %r." % (param, self)) def _resolveParam(self, param): @@ -175,7 +210,8 @@ def _resolveParam(self, param): @staticmethod def _dummy(): """ - Returns a dummy Params instance used as a placeholder to generate docs. + Returns a dummy Params instance used as a placeholder to + generate docs. """ dummy = Params() dummy.uid = "undefined" @@ -186,7 +222,7 @@ def _set(self, **kwargs): Sets user-supplied params. """ for param, value in kwargs.items(): - self.paramMap[getattr(self, param)] = value + self._paramMap[getattr(self, param)] = value return self def _setDefault(self, **kwargs): @@ -194,5 +230,19 @@ def _setDefault(self, **kwargs): Sets default params. """ for param, value in kwargs.items(): - self.defaultParamMap[getattr(self, param)] = value + self._defaultParamMap[getattr(self, param)] = value return self + + def _copyValues(self, to, extra={}): + """ + Copies param values from this instance to another instance for + params shared by them. + :param to: the target instance + :param extra: extra params to be copied + :return: the target instance with param values copied + """ + paramMap = self.extractParamMap(extra) + for p in self.params: + if p in paramMap and to.hasParam(p.name): + to._set(**{p.name: paramMap[p]}) + return to diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 6fa9b8c2cf36..69efc424ec4e 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -56,9 +56,10 @@ def _gen_param_header(name, doc, defaultValueStr): def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc") - if $defaultValueStr is not None: - self._setDefault($name=$defaultValueStr)''' + self.$name = Param(self, "$name", "$doc")''' + if defaultValueStr is not None: + template += ''' + self._setDefault($name=$defaultValueStr)''' Name = name[0].upper() + name[1:] return template \ @@ -83,7 +84,7 @@ def set$Name(self, value): """ Sets the value of :py:attr:`$name`. """ - self.paramMap[self.$name] = value + self._paramMap[self.$name] = value return self def get$Name(self): @@ -115,10 +116,10 @@ def get$Name(self): ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"), ("inputCol", "input column name", None), ("inputCols", "input column names", None), - ("outputCol", "output column name", None), + ("outputCol", "output column name", "self.uid + '__output'"), ("numFeatures", "number of features", None), ("checkpointInterval", "checkpoint interval (>= 1)", None), - ("seed", "random seed", None), + ("seed", "random seed", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms", None), ("stepSize", "Step size to be used for each iteration of optimization.", None)] code = [] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index b116f05a068d..bc088e4c29e2 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -32,14 +32,12 @@ def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations (>= 0) self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)") - if None is not None: - self._setDefault(maxIter=None) def setMaxIter(self, value): """ Sets the value of :py:attr:`maxIter`. """ - self.paramMap[self.maxIter] = value + self._paramMap[self.maxIter] = value return self def getMaxIter(self): @@ -61,14 +59,12 @@ def __init__(self): super(HasRegParam, self).__init__() #: param for regularization parameter (>= 0) self.regParam = Param(self, "regParam", "regularization parameter (>= 0)") - if None is not None: - self._setDefault(regParam=None) def setRegParam(self, value): """ Sets the value of :py:attr:`regParam`. """ - self.paramMap[self.regParam] = value + self._paramMap[self.regParam] = value return self def getRegParam(self): @@ -90,14 +86,13 @@ def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name self.featuresCol = Param(self, "featuresCol", "features column name") - if 'features' is not None: - self._setDefault(featuresCol='features') + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ Sets the value of :py:attr:`featuresCol`. """ - self.paramMap[self.featuresCol] = value + self._paramMap[self.featuresCol] = value return self def getFeaturesCol(self): @@ -119,14 +114,13 @@ def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name self.labelCol = Param(self, "labelCol", "label column name") - if 'label' is not None: - self._setDefault(labelCol='label') + self._setDefault(labelCol='label') def setLabelCol(self, value): """ Sets the value of :py:attr:`labelCol`. """ - self.paramMap[self.labelCol] = value + self._paramMap[self.labelCol] = value return self def getLabelCol(self): @@ -148,14 +142,13 @@ def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name self.predictionCol = Param(self, "predictionCol", "prediction column name") - if 'prediction' is not None: - self._setDefault(predictionCol='prediction') + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ Sets the value of :py:attr:`predictionCol`. """ - self.paramMap[self.predictionCol] = value + self._paramMap[self.predictionCol] = value return self def getPredictionCol(self): @@ -177,14 +170,13 @@ def __init__(self): super(HasProbabilityCol, self).__init__() #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") - if 'probability' is not None: - self._setDefault(probabilityCol='probability') + self._setDefault(probabilityCol='probability') def setProbabilityCol(self, value): """ Sets the value of :py:attr:`probabilityCol`. """ - self.paramMap[self.probabilityCol] = value + self._paramMap[self.probabilityCol] = value return self def getProbabilityCol(self): @@ -206,14 +198,13 @@ def __init__(self): super(HasRawPredictionCol, self).__init__() #: param for raw prediction (a.k.a. confidence) column name self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") - if 'rawPrediction' is not None: - self._setDefault(rawPredictionCol='rawPrediction') + self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): """ Sets the value of :py:attr:`rawPredictionCol`. """ - self.paramMap[self.rawPredictionCol] = value + self._paramMap[self.rawPredictionCol] = value return self def getRawPredictionCol(self): @@ -235,14 +226,12 @@ def __init__(self): super(HasInputCol, self).__init__() #: param for input column name self.inputCol = Param(self, "inputCol", "input column name") - if None is not None: - self._setDefault(inputCol=None) def setInputCol(self, value): """ Sets the value of :py:attr:`inputCol`. """ - self.paramMap[self.inputCol] = value + self._paramMap[self.inputCol] = value return self def getInputCol(self): @@ -264,14 +253,12 @@ def __init__(self): super(HasInputCols, self).__init__() #: param for input column names self.inputCols = Param(self, "inputCols", "input column names") - if None is not None: - self._setDefault(inputCols=None) def setInputCols(self, value): """ Sets the value of :py:attr:`inputCols`. """ - self.paramMap[self.inputCols] = value + self._paramMap[self.inputCols] = value return self def getInputCols(self): @@ -293,14 +280,13 @@ def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name self.outputCol = Param(self, "outputCol", "output column name") - if None is not None: - self._setDefault(outputCol=None) + self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): """ Sets the value of :py:attr:`outputCol`. """ - self.paramMap[self.outputCol] = value + self._paramMap[self.outputCol] = value return self def getOutputCol(self): @@ -322,14 +308,12 @@ def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features self.numFeatures = Param(self, "numFeatures", "number of features") - if None is not None: - self._setDefault(numFeatures=None) def setNumFeatures(self, value): """ Sets the value of :py:attr:`numFeatures`. """ - self.paramMap[self.numFeatures] = value + self._paramMap[self.numFeatures] = value return self def getNumFeatures(self): @@ -351,14 +335,12 @@ def __init__(self): super(HasCheckpointInterval, self).__init__() #: param for checkpoint interval (>= 1) self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)") - if None is not None: - self._setDefault(checkpointInterval=None) def setCheckpointInterval(self, value): """ Sets the value of :py:attr:`checkpointInterval`. """ - self.paramMap[self.checkpointInterval] = value + self._paramMap[self.checkpointInterval] = value return self def getCheckpointInterval(self): @@ -380,14 +362,13 @@ def __init__(self): super(HasSeed, self).__init__() #: param for random seed self.seed = Param(self, "seed", "random seed") - if None is not None: - self._setDefault(seed=None) + self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): """ Sets the value of :py:attr:`seed`. """ - self.paramMap[self.seed] = value + self._paramMap[self.seed] = value return self def getSeed(self): @@ -409,14 +390,12 @@ def __init__(self): super(HasTol, self).__init__() #: param for the convergence tolerance for iterative algorithms self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms") - if None is not None: - self._setDefault(tol=None) def setTol(self, value): """ Sets the value of :py:attr:`tol`. """ - self.paramMap[self.tol] = value + self._paramMap[self.tol] = value return self def getTol(self): @@ -438,14 +417,12 @@ def __init__(self): super(HasStepSize, self).__init__() #: param for Step size to be used for each iteration of optimization. self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.") - if None is not None: - self._setDefault(stepSize=None) def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self.paramMap[self.stepSize] = value + self._paramMap[self.stepSize] = value return self def getStepSize(self): @@ -467,6 +444,7 @@ class DecisionTreeParams(Params): minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -482,12 +460,12 @@ def __init__(self): self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ - self.paramMap[self.maxDepth] = value + self._paramMap[self.maxDepth] = value return self def getMaxDepth(self): @@ -500,7 +478,7 @@ def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ - self.paramMap[self.maxBins] = value + self._paramMap[self.maxBins] = value return self def getMaxBins(self): @@ -513,7 +491,7 @@ def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. """ - self.paramMap[self.minInstancesPerNode] = value + self._paramMap[self.minInstancesPerNode] = value return self def getMinInstancesPerNode(self): @@ -526,7 +504,7 @@ def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. """ - self.paramMap[self.minInfoGain] = value + self._paramMap[self.minInfoGain] = value return self def getMinInfoGain(self): @@ -539,7 +517,7 @@ def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ - self.paramMap[self.maxMemoryInMB] = value + self._paramMap[self.maxMemoryInMB] = value return self def getMaxMemoryInMB(self): @@ -552,7 +530,7 @@ def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. """ - self.paramMap[self.cacheNodeIds] = value + self._paramMap[self.cacheNodeIds] = value return self def getCacheNodeIds(self): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a328bcf84a2e..9889f56cac9e 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -31,18 +31,42 @@ class Estimator(Params): __metaclass__ = ABCMeta @abstractmethod - def fit(self, dataset, params={}): + def _fit(self, dataset): """ - Fits a model to the input dataset with optional parameters. + Fits a model to the input dataset. This is called by the + default implementation of fit. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overwrites embedded - params :returns: fitted model """ raise NotImplementedError() + def fit(self, dataset, params=None): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. If a list/tuple of param maps is given, + this calls fit on each param map and returns a + list of models. + :returns: fitted model(s) + """ + if params is None: + params = dict() + if isinstance(params, (list, tuple)): + return [self.fit(dataset, paramMap) for paramMap in params] + elif isinstance(params, dict): + if params: + return self.copy(params)._fit(dataset) + else: + return self._fit(dataset) + else: + raise ValueError("Params must be either a param map or a list/tuple of param maps, " + "but got %s." % type(params)) + @inherit_doc class Transformer(Params): @@ -54,18 +78,36 @@ class Transformer(Params): __metaclass__ = ABCMeta @abstractmethod - def transform(self, dataset, params={}): + def _transform(self, dataset): """ Transforms the input dataset with optional parameters. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overwrites embedded - params :returns: transformed dataset """ raise NotImplementedError() + def transform(self, dataset, params=None): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. + :returns: transformed dataset + """ + if params is None: + params = dict() + if isinstance(params, dict): + if params: + return self.copy(params,)._transform(dataset) + else: + return self._transform(dataset) + else: + raise ValueError("Params must be either a param map but got %s." % type(params)) + @inherit_doc class Model(Transformer): @@ -97,10 +139,12 @@ class Pipeline(Estimator): """ @keyword_only - def __init__(self, stages=[]): + def __init__(self, stages=None): """ __init__(self, stages=[]) """ + if stages is None: + stages = [] super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") @@ -113,28 +157,29 @@ def setStages(self, value): :param value: a list of transformers or estimators :return: the pipeline instance """ - self.paramMap[self.stages] = value + self._paramMap[self.stages] = value return self def getStages(self): """ Get pipeline stages. """ - if self.stages in self.paramMap: - return self.paramMap[self.stages] + if self.stages in self._paramMap: + return self._paramMap[self.stages] @keyword_only - def setParams(self, stages=[]): + def setParams(self, stages=None): """ setParams(self, stages=[]) Sets params for Pipeline. """ + if stages is None: + stages = [] kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def fit(self, dataset, params={}): - paramMap = self.extractParamMap(params) - stages = paramMap[self.stages] + def _fit(self, dataset): + stages = self.getStages() for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): raise TypeError( @@ -148,16 +193,23 @@ def fit(self, dataset, params={}): if i <= indexOfLastEstimator: if isinstance(stage, Transformer): transformers.append(stage) - dataset = stage.transform(dataset, paramMap) + dataset = stage.transform(dataset) else: # must be an Estimator - model = stage.fit(dataset, paramMap) + model = stage.fit(dataset) transformers.append(model) if i < indexOfLastEstimator: - dataset = model.transform(dataset, paramMap) + dataset = model.transform(dataset) else: transformers.append(stage) return PipelineModel(transformers) + def copy(self, extra=None): + if extra is None: + extra = dict() + that = Params.copy(self, extra) + stages = [stage.copy(extra) for stage in that.getStages()] + return that.setStages(stages) + @inherit_doc class PipelineModel(Model): @@ -165,33 +217,17 @@ class PipelineModel(Model): Represents a compiled pipeline with transformers and fitted models. """ - def __init__(self, transformers): + def __init__(self, stages): super(PipelineModel, self).__init__() - self.transformers = transformers + self.stages = stages - def transform(self, dataset, params={}): - paramMap = self.extractParamMap(params) - for t in self.transformers: - dataset = t.transform(dataset, paramMap) + def _transform(self, dataset): + for t in self.stages: + dataset = t.transform(dataset) return dataset - -class Evaluator(Params): - """ - Base class for evaluators that compute metrics from predictions. - """ - - __metaclass__ = ABCMeta - - @abstractmethod - def evaluate(self, dataset, params={}): - """ - Evaluates the output. - - :param dataset: a dataset that contains labels/observations and - predictions - :param params: an optional param map that overrides embedded - params - :return: metric - """ - raise NotImplementedError() + def copy(self, extra=None): + if extra is None: + extra = dict() + stages = [stage.copy(extra) for stage in self.stages] + return PipelineModel(stages) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b2439cbd9652..b06099ac0aee 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha indicated user preferences rather than explicit ratings given to items. + >>> df = sqlContext.createDataFrame( + ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ... ["user", "item", "rating"]) >>> als = ALS(rank=10, maxIter=5) >>> model = als.fit(df) + >>> model.rank + 10 + >>> model.userFactors.orderBy("id").collect() + [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)] >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] @@ -74,7 +81,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> predictions[2] Row(user=2, item=0, prediction=-1.15...) """ - _java_class = "org.apache.spark.ml.recommendation.ALS" + # a placeholder to make it appear in the generated doc rank = Param(Params._dummy(), "rank", "rank of the factorization") numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks") @@ -89,14 +96,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): """ __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ - implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=0, \ + implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=false, checkpointInterval=10) """ super(ALS, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) self.rank = Param(self, "rank", "rank of the factorization") self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks") self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks") @@ -108,18 +116,18 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.nonnegative = Param(self, "nonnegative", "whether to use nonnegative constraint for least squares") self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): """ setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, \ + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=False, checkpointInterval=10) Sets params for ALS. """ @@ -133,7 +141,7 @@ def setRank(self, value): """ Sets the value of :py:attr:`rank`. """ - self.paramMap[self.rank] = value + self._paramMap[self.rank] = value return self def getRank(self): @@ -146,7 +154,7 @@ def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. """ - self.paramMap[self.numUserBlocks] = value + self._paramMap[self.numUserBlocks] = value return self def getNumUserBlocks(self): @@ -159,7 +167,7 @@ def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. """ - self.paramMap[self.numItemBlocks] = value + self._paramMap[self.numItemBlocks] = value return self def getNumItemBlocks(self): @@ -172,14 +180,14 @@ def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. """ - self.paramMap[self.numUserBlocks] = value - self.paramMap[self.numItemBlocks] = value + self._paramMap[self.numUserBlocks] = value + self._paramMap[self.numItemBlocks] = value def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. """ - self.paramMap[self.implicitPrefs] = value + self._paramMap[self.implicitPrefs] = value return self def getImplicitPrefs(self): @@ -192,7 +200,7 @@ def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. """ - self.paramMap[self.alpha] = value + self._paramMap[self.alpha] = value return self def getAlpha(self): @@ -205,7 +213,7 @@ def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. """ - self.paramMap[self.userCol] = value + self._paramMap[self.userCol] = value return self def getUserCol(self): @@ -218,7 +226,7 @@ def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. """ - self.paramMap[self.itemCol] = value + self._paramMap[self.itemCol] = value return self def getItemCol(self): @@ -231,7 +239,7 @@ def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. """ - self.paramMap[self.ratingCol] = value + self._paramMap[self.ratingCol] = value return self def getRatingCol(self): @@ -244,7 +252,7 @@ def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. """ - self.paramMap[self.nonnegative] = value + self._paramMap[self.nonnegative] = value return self def getNonnegative(self): @@ -259,6 +267,27 @@ class ALSModel(JavaModel): Model fitted by ALS. """ + @property + def rank(self): + """rank of the matrix factorization model""" + return self._call_java("rank") + + @property + def userFactors(self): + """ + a DataFrame that stores user factors in two columns: `id` and + `features` + """ + return self._call_java("userFactors") + + @property + def itemFactors(self): + """ + a DataFrame that stores item factors in two columns: `id` and + `features` + """ + return self._call_java("itemFactors") + if __name__ == "__main__": import doctest @@ -271,8 +300,6 @@ class ALSModel(JavaModel): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), - (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"]) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() if failure_count: diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index ef77e1932718..b139e27372d8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -62,7 +62,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction ... TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.regression.LinearRegression" + # a placeholder to make it appear in the generated doc elasticNetParam = \ Param(Params._dummy(), "elasticNetParam", @@ -77,6 +77,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) """ super(LinearRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.LinearRegression", self.uid) #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty # is an L2 penalty. For alpha = 1, it is an L1 penalty. self.elasticNetParam = \ @@ -105,7 +107,7 @@ def setElasticNetParam(self, value): """ Sets the value of :py:attr:`elasticNetParam`. """ - self.paramMap[self.elasticNetParam] = value + self._paramMap[self.elasticNetParam] = value return self def getElasticNetParam(self): @@ -178,7 +180,6 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi 1.0 """ - _java_class = "org.apache.spark.ml.regression.DecisionTreeRegressor" # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + @@ -194,6 +195,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") """ super(DecisionTreeRegressor, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid) #: param for Criterion used for information gain calculation (case-insensitive). self.impurity = \ Param(self, "impurity", @@ -226,7 +229,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self.paramMap[self.impurity] = value + self._paramMap[self.impurity] = value return self def getImpurity(self): @@ -254,7 +257,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2) + >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42) >>> model = rf.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction @@ -264,7 +267,6 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi 0.5 """ - _java_class = "org.apache.spark.ml.regression.RandomForestRegressor" # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + @@ -282,14 +284,17 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - numTrees=20, featureSubsetStrategy="auto", seed=42): + numTrees=20, featureSubsetStrategy="auto", seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - impurity="variance", numTrees=20, featureSubsetStrategy="auto", seed=42) + impurity="variance", numTrees=20, \ + featureSubsetStrategy="auto", seed=None) """ super(RandomForestRegressor, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.RandomForestRegressor", self.uid) #: param for Criterion used for information gain calculation (case-insensitive). self.impurity = \ Param(self, "impurity", @@ -308,7 +313,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "The number of features to consider for splits at each tree node. Supported " + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="variance", numTrees=20, featureSubsetStrategy="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -316,12 +321,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="variance", numTrees=20, featureSubsetStrategy="auto"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ impurity="variance", numTrees=20, featureSubsetStrategy="auto") Sets params for linear regression. """ @@ -335,7 +340,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self.paramMap[self.impurity] = value + self._paramMap[self.impurity] = value return self def getImpurity(self): @@ -348,7 +353,7 @@ def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. """ - self.paramMap[self.subsamplingRate] = value + self._paramMap[self.subsamplingRate] = value return self def getSubsamplingRate(self): @@ -361,7 +366,7 @@ def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. """ - self.paramMap[self.numTrees] = value + self._paramMap[self.numTrees] = value return self def getNumTrees(self): @@ -374,7 +379,7 @@ def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. """ - self.paramMap[self.featureSubsetStrategy] = value + self._paramMap[self.featureSubsetStrategy] = value return self def getFeatureSubsetStrategy(self): @@ -412,7 +417,6 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, 1.0 """ - _java_class = "org.apache.spark.ml.regression.GBTRegressor" # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + @@ -436,6 +440,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred lossType="squared", maxIter=20, stepSize=0.1) """ super(GBTRegressor, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) #: param for Loss function which GBT tries to minimize (case-insensitive). self.lossType = Param(self, "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + @@ -477,7 +482,7 @@ def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. """ - self.paramMap[self.lossType] = value + self._paramMap[self.lossType] = value return self def getLossType(self): @@ -490,7 +495,7 @@ def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. """ - self.paramMap[self.subsamplingRate] = value + self._paramMap[self.subsamplingRate] = value return self def getSubsamplingRate(self): @@ -503,7 +508,7 @@ def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self.paramMap[self.stepSize] = value + self._paramMap[self.stepSize] = value return self def getStepSize(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ba6478dcd58a..c151d21fd661 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,10 +31,13 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame -from pyspark.ml.param import Param -from pyspark.ml.param.shared import HasMaxIter, HasInputCol -from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer +from pyspark.sql import DataFrame, SQLContext +from pyspark.ml.param import Param, Params +from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.util import keyword_only +from pyspark.ml import Estimator, Model, Pipeline, Transformer +from pyspark.ml.feature import * +from pyspark.mllib.linalg import DenseVector class MockDataset(DataFrame): @@ -43,44 +46,43 @@ def __init__(self): self.index = 0 -class MockTransformer(Transformer): +class HasFake(Params): + + def __init__(self): + super(HasFake, self).__init__() + self.fake = Param(self, "fake", "fake param") + + def getFake(self): + return self.getOrDefault(self.fake) + + +class MockTransformer(Transformer, HasFake): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake") self.dataset_index = None - self.fake_param_value = None - def transform(self, dataset, params={}): + def _transform(self, dataset): self.dataset_index = dataset.index - if self.fake in params: - self.fake_param_value = params[self.fake] dataset.index += 1 return dataset -class MockEstimator(Estimator): +class MockEstimator(Estimator, HasFake): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake") self.dataset_index = None - self.fake_param_value = None - self.model = None - def fit(self, dataset, params={}): + def _fit(self, dataset): self.dataset_index = dataset.index - if self.fake in params: - self.fake_param_value = params[self.fake] model = MockModel() - self.model = model + self._copyValues(model) return model -class MockModel(MockTransformer, Model): - - def __init__(self): - super(MockModel, self).__init__() +class MockModel(MockTransformer, Model, HasFake): + pass class PipelineTests(PySparkTestCase): @@ -91,19 +93,17 @@ def test_pipeline(self): transformer1 = MockTransformer() estimator2 = MockEstimator() transformer3 = MockTransformer() - pipeline = Pipeline() \ - .setStages([estimator0, transformer1, estimator2, transformer3]) + pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - self.assertEqual(0, estimator0.dataset_index) - self.assertEqual(0, estimator0.fake_param_value) - model0 = estimator0.model + model0, transformer1, model2, transformer3 = pipeline_model.stages self.assertEqual(0, model0.dataset_index) + self.assertEqual(0, model0.getFake()) self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.fake_param_value) - self.assertEqual(2, estimator2.dataset_index) - model2 = estimator2.model - self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should " - "not be called during fit.") + self.assertEqual(1, transformer1.getFake()) + self.assertEqual(2, dataset.index) + self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") + self.assertIsNone(transformer3.dataset_index, + "The last transformer shouldn't be called in fit.") dataset = pipeline_model.transform(dataset) self.assertEqual(2, model0.dataset_index) self.assertEqual(3, transformer1.dataset_index) @@ -112,14 +112,46 @@ def test_pipeline(self): self.assertEqual(6, dataset.index) -class TestParams(HasMaxIter, HasInputCol): +class TestParams(HasMaxIter, HasInputCol, HasSeed): """ - A subclass of Params mixed with HasMaxIter and HasInputCol. + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. """ - - def __init__(self): + @keyword_only + def __init__(self, seed=None): super(TestParams, self).__init__() self._setDefault(maxIter=10) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(OtherTestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) class ParamTests(PySparkTestCase): @@ -129,16 +161,18 @@ def test_param(self): maxIter = testParams.maxIter self.assertEqual(maxIter.name, "maxIter") self.assertEqual(maxIter.doc, "max number of iterations (>= 0)") - self.assertTrue(maxIter.parent is testParams) + self.assertTrue(maxIter.parent == testParams.uid) def test_params(self): testParams = TestParams() maxIter = testParams.maxIter inputCol = testParams.inputCol + seed = testParams.seed params = testParams.params - self.assertEqual(params, [inputCol, maxIter]) + self.assertEqual(params, [inputCol, maxIter, seed]) + self.assertTrue(testParams.hasParam(maxIter)) self.assertTrue(testParams.hasDefault(maxIter)) self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) @@ -147,16 +181,87 @@ def test_params(self): self.assertTrue(testParams.isSet(maxIter)) self.assertEquals(testParams.getMaxIter(), 100) + self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) self.assertFalse(testParams.isSet(inputCol)) self.assertFalse(testParams.isDefined(inputCol)) with self.assertRaises(KeyError): testParams.getInputCol() + # Since the default is normally random, set it to a known number for debug str + testParams._setDefault(seed=41) + testParams.setSeed(43) + self.assertEquals( testParams.explainParams(), "\n".join(["inputCol: input column name (undefined)", - "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"])) + "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", + "seed: random seed (default: 41, current: 43)"])) + + def test_hasseed(self): + noSeedSpecd = TestParams() + withSeedSpecd = TestParams(seed=42) + other = OtherTestParams() + # Check that we no longer use 42 as the magic number + self.assertNotEqual(noSeedSpecd.getSeed(), 42) + origSeed = noSeedSpecd.getSeed() + # Check that we only compute the seed once + self.assertEqual(noSeedSpecd.getSeed(), origSeed) + # Check that a specified seed is honored + self.assertEqual(withSeedSpecd.getSeed(), 42) + # Check that a different class has a different seed + self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + + +class FeatureTests(PySparkTestCase): + + def test_binarizer(self): + b0 = Binarizer() + self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) + self.assertTrue(all([~b0.isSet(p) for p in b0.params])) + self.assertTrue(b0.hasDefault(b0.threshold)) + self.assertEqual(b0.getThreshold(), 0.0) + b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) + self.assertTrue(all([b0.isSet(p) for p in b0.params])) + self.assertEqual(b0.getThreshold(), 1.0) + self.assertEqual(b0.getInputCol(), "input") + self.assertEqual(b0.getOutputCol(), "output") + + b0c = b0.copy({b0.threshold: 2.0}) + self.assertEqual(b0c.uid, b0.uid) + self.assertListEqual(b0c.params, b0.params) + self.assertEqual(b0c.getThreshold(), 2.0) + + b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") + self.assertNotEqual(b1.uid, b0.uid) + self.assertEqual(b1.getThreshold(), 2.0) + self.assertEqual(b1.getInputCol(), "input") + self.assertEqual(b1.getOutputCol(), "output") + + def test_idf(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (DenseVector([1.0, 2.0]),), + (DenseVector([0.0, 1.0]),), + (DenseVector([3.0, 0.2]),)], ["tf"]) + idf0 = IDF(inputCol="tf") + self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) + idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) + self.assertEqual(idf0m.uid, idf0.uid, + "Model should inherit the UID from its parent estimator.") + output = idf0m.transform(dataset) + self.assertIsNotNone(output.head().idf) + + def test_ngram(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + ([["a", "b", "c", "d", "e"]])], ["input"]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) if __name__ == "__main__": diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 86f4dc7368be..0bf988fd72f1 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -91,20 +91,19 @@ class CrossValidator(Estimator): >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator >>> from pyspark.mllib.linalg import Vectors >>> dataset = sqlContext.createDataFrame( - ... [(Vectors.dense([0.0, 1.0]), 0.0), - ... (Vectors.dense([1.0, 2.0]), 1.0), - ... (Vectors.dense([0.55, 3.0]), 0.0), - ... (Vectors.dense([0.45, 4.0]), 1.0), - ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10, + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, ... ["features", "label"]) >>> lr = LogisticRegression() - >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - >>> # SPARK-7432: The following test is flaky. - >>> # cvModel = cv.fit(dataset) - >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset) - >>> # cvModel.transform(dataset).collect() == expected.collect() + >>> cvModel = cv.fit(dataset) + >>> evaluator.evaluate(cvModel.transform(dataset)) + 0.8333... """ # a placeholder to make it appear in the generated doc @@ -155,7 +154,7 @@ def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. """ - self.paramMap[self.estimator] = value + self._paramMap[self.estimator] = value return self def getEstimator(self): @@ -168,7 +167,7 @@ def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. """ - self.paramMap[self.estimatorParamMaps] = value + self._paramMap[self.estimatorParamMaps] = value return self def getEstimatorParamMaps(self): @@ -181,7 +180,7 @@ def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. """ - self.paramMap[self.evaluator] = value + self._paramMap[self.evaluator] = value return self def getEvaluator(self): @@ -194,7 +193,7 @@ def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. """ - self.paramMap[self.numFolds] = value + self._paramMap[self.numFolds] = value return self def getNumFolds(self): @@ -203,13 +202,12 @@ def getNumFolds(self): """ return self.getOrDefault(self.numFolds) - def fit(self, dataset, params={}): - paramMap = self.extractParamMap(params) - est = paramMap[self.estimator] - epm = paramMap[self.estimatorParamMaps] + def _fit(self, dataset): + est = self.getOrDefault(self.estimator) + epm = self.getOrDefault(self.estimatorParamMaps) numModels = len(epm) - eva = paramMap[self.evaluator] - nFolds = paramMap[self.numFolds] + eva = self.getOrDefault(self.evaluator) + nFolds = self.getOrDefault(self.numFolds) h = 1.0 / nFolds randCol = self.uid + "_rand" df = dataset.select("*", rand(0).alias(randCol)) @@ -229,6 +227,15 @@ def fit(self, dataset, params={}): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + def copy(self, extra={}): + newCV = Params.copy(self, extra) + if self.isSet(self.estimator): + newCV.setEstimator(self.getEstimator().copy(extra)) + # estimatorParamMaps remain the same + if self.isSet(self.evaluator): + newCV.setEvaluator(self.getEvaluator().copy(extra)) + return newCV + class CrossValidatorModel(Model): """ @@ -240,8 +247,19 @@ def __init__(self, bestModel): #: best model from cross validation self.bestModel = bestModel - def transform(self, dataset, params={}): - return self.bestModel.transform(dataset, params) + def _transform(self, dataset): + return self.bestModel.transform(dataset) + + def copy(self, extra={}): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies the underlying bestModel, + creates a deep copy of the embedded paramMap, and + copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + return CrossValidatorModel(self.bestModel.copy(extra)) if __name__ == "__main__": diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index d3cb100a9efa..cee9d67b0532 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -39,9 +39,16 @@ class Identifiable(object): """ def __init__(self): - #: A unique id for the object. The default implementation - #: concatenates the class name, "_", and 8 random hex chars. - self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] + #: A unique id for the object. + self.uid = self._randomUID() def __repr__(self): return self.uid + + @classmethod + def _randomUID(cls): + """ + Generate a unique id for the object. The default implementation + concatenates the class name, "_", and 12 random hex chars. + """ + return cls.__name__ + "_" + uuid.uuid4().hex[12:] diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index dda6c6aba304..253705bde913 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -20,7 +20,7 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model +from pyspark.ml.pipeline import Estimator, Transformer, Model from pyspark.mllib.common import inherit_doc, _java2py, _py2java @@ -45,46 +45,61 @@ class JavaWrapper(Params): __metaclass__ = ABCMeta - #: Fully-qualified class name of the wrapped Java component. - _java_class = None + #: The wrapped Java companion object. Subclasses should initialize + #: it properly. The param values in the Java object should be + #: synced with the Python wrapper in fit/transform/evaluate/copy. + _java_obj = None - def _java_obj(self): + @staticmethod + def _new_java_obj(java_class, *args): """ - Returns or creates a Java object. + Construct a new Java object. """ + sc = SparkContext._active_spark_context java_obj = _jvm() - for name in self._java_class.split("."): + for name in java_class.split("."): java_obj = getattr(java_obj, name) - return java_obj() + java_args = [_py2java(sc, arg) for arg in args] + return java_obj(*java_args) - def _transfer_params_to_java(self, params, java_obj): + def _make_java_param_pair(self, param, value): """ - Transforms the embedded params and additional params to the - input Java object. - :param params: additional params (overwriting embedded values) - :param java_obj: Java object to receive the params + Makes a Java parm pair. + """ + sc = SparkContext._active_spark_context + param = self._resolveParam(param) + java_param = self._java_obj.getParam(param.name) + java_value = _py2java(sc, value) + return java_param.w(java_value) + + def _transfer_params_to_java(self): + """ + Transforms the embedded params to the companion Java object. """ - paramMap = self.extractParamMap(params) + paramMap = self.extractParamMap() for param in self.params: if param in paramMap: - value = paramMap[param] - java_param = java_obj.getParam(param.name) - java_obj.set(java_param.w(value)) + pair = self._make_java_param_pair(param, paramMap[param]) + self._java_obj.set(pair) - def _empty_java_param_map(self): + def _transfer_params_from_java(self): + """ + Transforms the embedded params from the companion Java object. + """ + sc = SparkContext._active_spark_context + for param in self.params: + if self._java_obj.hasParam(param.name): + java_param = self._java_obj.getParam(param.name) + value = _java2py(sc, self._java_obj.getOrDefault(java_param)) + self._paramMap[param] = value + + @staticmethod + def _empty_java_param_map(): """ Returns an empty Java ParamMap reference. """ return _jvm().org.apache.spark.ml.param.ParamMap() - def _create_java_param_map(self, params, java_obj): - paramMap = self._empty_java_param_map() - for param, value in params.items(): - if param.parent is self: - java_param = java_obj.getParam(param.name) - paramMap.put(java_param.w(value)) - return paramMap - @inherit_doc class JavaEstimator(Estimator, JavaWrapper): @@ -99,9 +114,9 @@ def _create_model(self, java_model): """ Creates a model from the input Java model reference. """ - return JavaModel(java_model) + raise NotImplementedError() - def _fit_java(self, dataset, params={}): + def _fit_java(self, dataset): """ Fits a Java model to the input dataset. :param dataset: input dataset, which is an instance of @@ -109,12 +124,11 @@ def _fit_java(self, dataset, params={}): :param params: additional params (overwriting embedded values) :return: fitted Java model """ - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + self._transfer_params_to_java() + return self._java_obj.fit(dataset._jdf) - def fit(self, dataset, params={}): - java_model = self._fit_java(dataset, params) + def _fit(self, dataset): + java_model = self._fit_java(dataset) return self._create_model(java_model) @@ -127,45 +141,49 @@ class JavaTransformer(Transformer, JavaWrapper): __metaclass__ = ABCMeta - def transform(self, dataset, params={}): - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx) + def _transform(self, dataset): + self._transfer_params_to_java() + return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) @inherit_doc class JavaModel(Model, JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala - implementations. + implementations. Subclasses should inherit this class before + param mix-ins, because this sets the UID from the Java model. """ __metaclass__ = ABCMeta def __init__(self, java_model): - super(JavaTransformer, self).__init__() - self._java_model = java_model + """ + Initialize this instance with a Java model object. + Subclasses should call this constructor, initialize params, + and then call _transformer_params_from_java. + """ + super(JavaModel, self).__init__() + self._java_obj = java_model + self.uid = java_model.uid() - def _java_obj(self): - return self._java_model + def copy(self, extra=None): + """ + Creates a copy of this instance with the same uid and some + extra params. This implementation first calls Params.copy and + then make a copy of the companion Java model with extra params. + So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + if extra is None: + extra = dict() + that = super(JavaModel, self).copy(extra) + that._java_obj = self._java_obj.copy(self._empty_java_param_map()) + that._transfer_params_to_java() + return that def _call_java(self, name, *args): - m = getattr(self._java_model, name) + m = getattr(self._java_obj, name) sc = SparkContext._active_spark_context java_args = [_py2java(sc, arg) for arg in args] return _java2py(sc, m(*java_args)) - - -@inherit_doc -class JavaEvaluator(Evaluator, JavaWrapper): - """ - Base class for :py:class:`Evaluator`s that wrap Java/Scala - implementations. - """ - - __metaclass__ = ABCMeta - - def evaluate(self, dataset, params={}): - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return java_obj.evaluate(dataset._jdf, self._empty_java_param_map()) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 07507b2ad0d0..acba3a717d21 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -23,16 +23,10 @@ # MLlib currently needs NumPy 1.4+, so complain if lower import numpy -if numpy.version.version < '1.4': + +ver = [int(x) for x in numpy.version.version.split('.')[:2]] +if ver < [1, 4]: raise Exception("MLlib requires NumPy 1.4+") __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] - -import sys -from . import rand as random -modname = __name__ + '.random' -random.__name__ = modname -random.RandomRDDs.__module__ = modname -sys.modules[modname] = random -del modname, sys diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index a70c664a71fd..8f27c446a66e 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,20 +21,24 @@ from numpy import array from pyspark import RDD +from pyspark.streaming import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper +from pyspark.mllib.regression import ( + LabeledPoint, LinearModel, _regression_train_wrapper, + StreamingLinearAlgorithm) from pyspark.mllib.util import Saveable, Loader, inherit_doc __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', - 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes', + 'StreamingLogisticRegressionWithSGD'] class LinearClassificationModel(LinearModel): """ - A private abstract class representing a multiclass classification model. - The categories are represented by int values: 0, 1, 2, etc. + A private abstract class representing a multiclass classification + model. The categories are represented by int values: 0, 1, 2, etc. """ def __init__(self, weights, intercept): super(LinearClassificationModel, self).__init__(weights, intercept) @@ -44,10 +48,11 @@ def setThreshold(self, value): """ .. note:: Experimental - Sets the threshold that separates positive predictions from negative - predictions. An example with prediction score greater than or equal - to this threshold is identified as an positive, and negative otherwise. - It is used for binary classification only. + Sets the threshold that separates positive predictions from + negative predictions. An example with prediction score greater + than or equal to this threshold is identified as an positive, + and negative otherwise. It is used for binary classification + only. """ self._threshold = value @@ -56,8 +61,9 @@ def threshold(self): """ .. note:: Experimental - Returns the threshold (if any) used for converting raw prediction scores - into 0/1 predictions. It is used for binary classification only. + Returns the threshold (if any) used for converting raw + prediction scores into 0/1 predictions. It is used for + binary classification only. """ return self._threshold @@ -65,22 +71,35 @@ def clearThreshold(self): """ .. note:: Experimental - Clears the threshold so that `predict` will output raw prediction scores. - It is used for binary classification only. + Clears the threshold so that `predict` will output raw + prediction scores. It is used for binary classification only. """ self._threshold = None def predict(self, test): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ raise NotImplementedError class LogisticRegressionModel(LinearClassificationModel): - """A linear binary classification model derived from logistic regression. + """ + Classification model trained using Multinomial/Binary Logistic + Regression. + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. (Only used + in Binary Logistic Regression. In Multinomial Logistic + Regression, the intercepts will not be a single value, + so the intercepts will be part of the weights.) + :param numFeatures: the dimension of the features. + :param numClasses: the number of possible outcomes for k classes + classification problem in Multinomial Logistic Regression. + By default, it is binary logistic regression so numClasses + will be set to 2. >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -120,8 +139,9 @@ class LogisticRegressionModel(LinearClassificationModel): 1 >>> sameModel.predict(SparseVector(2, {0: 1.0})) 0 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> multi_class_data = [ @@ -161,8 +181,8 @@ def numClasses(self): def predict(self, x): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -225,16 +245,19 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.01). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.01). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -243,13 +266,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), @@ -267,12 +291,15 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType """ Train a logistic regression model on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.01). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.01). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -281,19 +308,21 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). - :param corrections: The number of corrections used in the LBFGS - update (default: 10). - :param tolerance: The convergence tolerance of iterations for - L-BFGS (default: 1e-4). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param corrections: The number of corrections used in the + LBFGS update (default: 10). + :param tolerance: The convergence tolerance of iterations + for L-BFGS (default: 1e-4). :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) - :param numClasses: The number of classes (i.e., outcomes) a label can take - in Multinomial Logistic Regression (default: 2). + algorithm should validate data before + training. (default: True) + :param numClasses: The number of classes (i.e., outcomes) a + label can take in Multinomial Logistic + Regression (default: 2). >>> data = [ ... LabeledPoint(0.0, [0.0, 1.0]), @@ -323,7 +352,11 @@ def train(rdd, i): class SVMModel(LinearClassificationModel): - """A support vector machine. + """ + Model for Support Vector Machines (SVMs). + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -359,8 +392,9 @@ class SVMModel(LinearClassificationModel): 1 >>> sameModel.predict(SparseVector(2, {0: -1.0})) 0 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass """ @@ -370,8 +404,8 @@ def __init__(self, weights, intercept): def predict(self, x): """ - Predict values for a single data point or an RDD of points using - the model trained. + Predict values for a single data point or an RDD of points + using the model trained. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -409,16 +443,19 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ Train a support vector machine on the given data. - :param data: The training data, an RDD of LabeledPoint. - :param iterations: The number of iterations (default: 100). + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param regParam: The regularizer parameter (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regType: The type of regularizer used for training - our model. + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization @@ -427,13 +464,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, (default: "l2") - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), @@ -449,9 +487,11 @@ class NaiveBayesModel(Saveable, Loader): """ Model for Naive Bayes classifiers. - Contains two parameters: - - pi: vector of logs of class priors (dimension C) - - theta: matrix of logs of class conditional probabilities (CxD) + :param labels: list of labels. + :param pi: log of class priors, whose dimension is C, + number of labels. + :param theta: log of class conditional probabilities, whose + dimension is C-by-D, where D is number of features. >>> data = [ ... LabeledPoint(0.0, [0.0, 0.0]), @@ -481,8 +521,9 @@ class NaiveBayesModel(Saveable, Loader): >>> sameModel = NaiveBayesModel.load(sc, path) >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ @@ -493,7 +534,10 @@ def __init__(self, labels, pi, theta): self.theta = theta def predict(self, x): - """Return the most likely class for a data vector or an RDD of vectors""" + """ + Return the most likely class for a data vector + or an RDD of vectors + """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) @@ -523,24 +567,76 @@ class NaiveBayes(object): @classmethod def train(cls, data, lambda_=1.0): """ - Train a Naive Bayes model given an RDD of (label, features) vectors. + Train a Naive Bayes model given an RDD of (label, features) + vectors. - This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which can - handle all kinds of discrete data. For example, by converting - documents into TF-IDF vectors, it can be used for document - classification. By making every vector a 0-1 vector, it can also be - used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). + This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which + can handle all kinds of discrete data. For example, by + converting documents into TF-IDF vectors, it can be used for + document classification. By making every vector a 0-1 vector, + it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). + The input feature values must be nonnegative. :param data: RDD of LabeledPoint. - :param lambda_: The smoothing parameter + :param lambda_: The smoothing parameter (default: 1.0). """ first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("`data` should be an RDD of LabeledPoint") - labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) + labels, pi, theta = callMLlibFunc("trainNaiveBayesModel", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) +@inherit_doc +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LogisticRegression with SGD on a batch of data. + + The weights obtained at the end of training a stream are used as initial + weights for the next batch. + + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Number of iterations run for each batch of data. + :param miniBatchFraction: Fraction of data on which SGD is run for each + iteration. + :param regParam: L2 Regularization parameter. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + self.stepSize = stepSize + self.numIterations = numIterations + self.regParam = regParam + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLogisticRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn. + """ + initialWeights = _convert_to_vector(initialWeights) + + # LogisticRegressionWithSGD does only binary classification. + self._model = LogisticRegressionModel( + initialWeights, 0, initialWeights.size, 2) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LogisticRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LogisticRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b55583f82223..a3eab635282f 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -21,16 +21,23 @@ if sys.version > '3': xrange = range -from numpy import array +from math import exp, log + +from numpy import array, random, tile + +from collections import namedtuple -from pyspark import RDD from pyspark import SparkContext -from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.rdd import RDD, ignore_unicode_prefix +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector from pyspark.mllib.stat.distribution import MultivariateGaussian -from pyspark.mllib.util import Saveable, Loader, inherit_doc +from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable +from pyspark.streaming import DStream -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', + 'PowerIterationClusteringModel', 'PowerIterationClustering', + 'StreamingKMeans', 'StreamingKMeansModel'] @inherit_doc @@ -75,8 +82,9 @@ class KMeansModel(Saveable, Loader): >>> sameModel = KMeansModel.load(sc, path) >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0]) True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ @@ -98,6 +106,9 @@ def predict(self, x): """Find the cluster to which x belongs in this model.""" best = 0 best_distance = float("inf") + if isinstance(x, RDD): + return x.map(self.predict) + x = _convert_to_vector(x) for i in xrange(len(self.centers)): distance = x.squared_distance(self.centers[i]) @@ -257,16 +268,293 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] - weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k, - convergenceTol, maxIterations, seed, initialModelWeights, - initialModelMu, initialModelSigma) + weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] return GaussianMixtureModel(weight, mvg_obj) +class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental + + Model produced by [[PowerIterationClustering]]. + + >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0), + ... (0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)] + >>> rdd = sc.parallelize(data, 2) + >>> model = PowerIterationClustering.train(rdd, 2, 100) + >>> model.k + 2 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = PowerIterationClusteringModel.load(sc, path) + >>> sameModel.k + 2 + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + """ + + @property + def k(self): + """ + Returns the number of clusters. + """ + return self.call("k") + + def assignments(self): + """ + Returns the cluster assignments of this model. + """ + return self.call("getAssignments").map( + lambda x: (PowerIterationClustering.Assignment(*x))) + + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model) + return PowerIterationClusteringModel(wrapper) + + +class PowerIterationClustering(object): + """ + .. note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm + developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. + From the abstract: PIC finds a very low-dimensional embedding of a + dataset using truncated power iteration on a normalized pair-wise + similarity matrix of the data. + """ + + @classmethod + def train(cls, rdd, k, maxIterations=100, initMode="random"): + """ + :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. + The similarity s,,ij,, must be nonnegative. + This is a symmetric matrix and hence s,,ij,, = s,,ji,,. + For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. + Tuples with i = j are ignored, because we assume + s,,ij,, = 0.0. + :param k: Number of clusters. + :param maxIterations: Maximum number of iterations of the + PIC algorithm. + :param initMode: Initialization mode. + """ + model = callMLlibFunc("trainPowerIterationClusteringModel", + rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) + return PowerIterationClusteringModel(model) + + class Assignment(namedtuple("Assignment", ["id", "cluster"])): + """ + Represents an (id, cluster) tuple. + """ + + +class StreamingKMeansModel(KMeansModel): + """ + .. note:: Experimental + + Clustering model which can perform an online update of the centroids. + + The update formula for each centroid is given by + + * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t) + * n_t+1 = n_t * a + m_t + + where + + * c_t: Centroid at the n_th iteration. + * n_t: Number of samples (or) weights associated with the centroid + at the n_th iteration. + * x_t: Centroid of the new data closest to c_t. + * m_t: Number of samples (or) weights of the new data closest to c_t + * c_t+1: New centroid. + * n_t+1: New number of weights. + * a: Decay Factor, which gives the forgetfulness. + + Note that if a is set to 1, it is the weighted mean of the previous + and new data. If it set to zero, the old centroids are completely + forgotten. + + :param clusterCenters: Initial cluster centers. + :param clusterWeights: List of weights assigned to each cluster. + + >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] + >>> initWeights = [1.0, 1.0] + >>> stkm = StreamingKMeansModel(initCenters, initWeights) + >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1], + ... [0.9, 0.9], [1.1, 1.1]]) + >>> stkm = stkm.update(data, 1.0, u"batches") + >>> stkm.centers + array([[ 0., 0.], + [ 1., 1.]]) + >>> stkm.predict([-0.1, -0.1]) + 0 + >>> stkm.predict([0.9, 0.9]) + 1 + >>> stkm.clusterWeights + [3.0, 3.0] + >>> decayFactor = 0.0 + >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])]) + >>> stkm = stkm.update(data, 0.0, u"batches") + >>> stkm.centers + array([[ 0.2, 0.2], + [ 1.5, 1.5]]) + >>> stkm.clusterWeights + [1.0, 1.0] + >>> stkm.predict([0.2, 0.2]) + 0 + >>> stkm.predict([1.5, 1.5]) + 1 + """ + def __init__(self, clusterCenters, clusterWeights): + super(StreamingKMeansModel, self).__init__(centers=clusterCenters) + self._clusterWeights = list(clusterWeights) + + @property + def clusterWeights(self): + """Return the cluster weights.""" + return self._clusterWeights + + @ignore_unicode_prefix + def update(self, data, decayFactor, timeUnit): + """Update the centroids, according to data + + :param data: Should be a RDD that represents the new data. + :param decayFactor: forgetfulness of the previous centroids. + :param timeUnit: Can be "batches" or "points". If points, then the + decay factor is raised to the power of number of new + points and if batches, it is used as it is. + """ + if not isinstance(data, RDD): + raise TypeError("Data should be of an RDD, got %s." % type(data)) + data = data.map(_convert_to_vector) + decayFactor = float(decayFactor) + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." % timeUnit) + vectorCenters = [_convert_to_vector(center) for center in self.centers] + updatedModel = callMLlibFunc( + "updateStreamingKMeansModel", vectorCenters, self._clusterWeights, + data, decayFactor, timeUnit) + self.centers = array(updatedModel[0]) + self._clusterWeights = list(updatedModel[1]) + return self + + +class StreamingKMeans(object): + """ + .. note:: Experimental + + Provides methods to set k, decayFactor, timeUnit to configure the + KMeans algorithm for fitting and predicting on incoming dstreams. + More details on how the centroids are updated are provided under the + docs of StreamingKMeansModel. + + :param k: int, number of clusters + :param decayFactor: float, forgetfulness of the previous centroids. + :param timeUnit: can be "batches" or "points". If points, then the + decayfactor is raised to the power of no. of new points. + """ + def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): + self._k = k + self._decayFactor = decayFactor + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." % timeUnit) + self._timeUnit = timeUnit + self._model = None + + def latestModel(self): + """Return the latest model""" + return self._model + + def _validate(self, dstream): + if self._model is None: + raise ValueError( + "Initial centers should be set either by setInitialCenters " + "or setRandomCenters.") + if not isinstance(dstream, DStream): + raise TypeError( + "Expected dstream to be of type DStream, " + "got type %s" % type(dstream)) + + def setK(self, k): + """Set number of clusters.""" + self._k = k + return self + + def setDecayFactor(self, decayFactor): + """Set decay factor.""" + self._decayFactor = decayFactor + return self + + def setHalfLife(self, halfLife, timeUnit): + """ + Set number of batches after which the centroids of that + particular batch has half the weightage. + """ + self._timeUnit = timeUnit + self._decayFactor = exp(log(0.5) / halfLife) + return self + + def setInitialCenters(self, centers, weights): + """ + Set initial centers. Should be set before calling trainOn. + """ + self._model = StreamingKMeansModel(centers, weights) + return self + + def setRandomCenters(self, dim, weight, seed): + """ + Set the initial centres to be random samples from + a gaussian population with constant weights. + """ + rng = random.RandomState(seed) + clusterCenters = rng.randn(self._k, dim) + clusterWeights = tile(weight, self._k) + self._model = StreamingKMeansModel(clusterCenters, clusterWeights) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + self._model.update(rdd, self._decayFactor, self._timeUnit) + + dstream.foreachRDD(update) + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + Returns a transformed dstream object + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + Returns a transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + def _test(): import doctest - globs = globals().copy() + import pyspark.mllib.clustering + globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index ba6058978880..855e85f57155 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -27,7 +27,7 @@ from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer - +from pyspark.sql import DataFrame, SQLContext # Hack for support float('inf') in Py4j _old_smart_decode = py4j.protocol.smart_decode @@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) + if clsName == 'DataFrame': + return DataFrame(r, SQLContext(sc)) + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4c777f2180dc..c5cf3a4e7ff2 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ Evaluator for binary classification. + :param scoreAndLabels: an RDD of (score, label) pairs + >>> scoreAndLabels = sc.parallelize([ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) >>> metrics = BinaryClassificationMetrics(scoreAndLabels) @@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ def __init__(self, scoreAndLabels): - """ - :param scoreAndLabels: an RDD of (score, label) pairs - """ sc = scoreAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ @@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper): """ Evaluator for regression. + :param predictionAndObservations: an RDD of (prediction, + observation) pairs. + >>> predictionAndObservations = sc.parallelize([ ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) @@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper): """ def __init__(self, predictionAndObservations): - """ - :param predictionAndObservations: an RDD of (prediction, observation) pairs. - """ sc = predictionAndObservations.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ @@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. + :param predictionAndLabels an RDD of (prediction, label) pairs. + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) @@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels an RDD of (prediction, label) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ @@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper): """ Evaluator for ranking algorithms. + :param predictionAndLabels: an RDD of (predicted ranking, + ground truth set) pairs. + >>> predictionAndLabels = sc.parallelize([ ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]), ... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]), @@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, @@ -334,16 +332,136 @@ def ndcgAt(self, k): """ Compute the average NDCG value of all the queries, truncated at ranking position k. The discounted cumulative gain at position k is computed as: - sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), + sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current implementation, the relevance value is binary. - - If a query has an empty ground truth set, zero will be used as ndcg together with + If a query has an empty ground truth set, zero will be used as NDCG together with a log warning. """ return self.call("ndcgAt", int(k)) +class MultilabelMetrics(JavaModelWrapper): + """ + Evaluator for multilabel classification. + + :param predictionAndLabels: an RDD of (predictions, labels) pairs, + both are non-null Arrays, each with + unique elements. + + >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), + ... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), + ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) + >>> metrics = MultilabelMetrics(predictionAndLabels) + >>> metrics.precision(0.0) + 1.0 + >>> metrics.recall(1.0) + 0.66... + >>> metrics.f1Measure(2.0) + 0.5 + >>> metrics.precision() + 0.66... + >>> metrics.recall() + 0.64... + >>> metrics.f1Measure() + 0.63... + >>> metrics.microPrecision + 0.72... + >>> metrics.microRecall + 0.66... + >>> metrics.microF1Measure + 0.69... + >>> metrics.hammingLoss + 0.33... + >>> metrics.subsetAccuracy + 0.28... + >>> metrics.accuracy + 0.54... + """ + + def __init__(self, predictionAndLabels): + sc = predictionAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndLabels, + schema=sql_ctx._inferSchema(predictionAndLabels)) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics + java_model = java_class(df._jdf) + super(MultilabelMetrics, self).__init__(java_model) + + def precision(self, label=None): + """ + Returns precision or precision for a given label (category) if specified. + """ + if label is None: + return self.call("precision") + else: + return self.call("precision", float(label)) + + def recall(self, label=None): + """ + Returns recall or recall for a given label (category) if specified. + """ + if label is None: + return self.call("recall") + else: + return self.call("recall", float(label)) + + def f1Measure(self, label=None): + """ + Returns f1Measure or f1Measure for a given label (category) if specified. + """ + if label is None: + return self.call("f1Measure") + else: + return self.call("f1Measure", float(label)) + + @property + def microPrecision(self): + """ + Returns micro-averaged label-based precision. + (equals to micro-averaged document-based precision) + """ + return self.call("microPrecision") + + @property + def microRecall(self): + """ + Returns micro-averaged label-based recall. + (equals to micro-averaged document-based recall) + """ + return self.call("microRecall") + + @property + def microF1Measure(self): + """ + Returns micro-averaged label-based f1-measure. + (equals to micro-averaged document-based f1-measure) + """ + return self.call("microF1Measure") + + @property + def hammingLoss(self): + """ + Returns Hamming-loss. + """ + return self.call("hammingLoss") + + @property + def subsetAccuracy(self): + """ + Returns subset accuracy. + (for equal sets of labels) + """ + return self.call("subsetAccuracy") + + @property + def accuracy(self): + """ + Returns accuracy. + """ + return self.call("accuracy") + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index aac305db6c19..f921e3ad1a31 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -33,12 +33,14 @@ from pyspark import SparkContext from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector +from pyspark.mllib.linalg import ( + Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', - 'ChiSqSelector', 'ChiSqSelectorModel'] + 'ChiSqSelector', 'ChiSqSelectorModel', 'ElementwiseProduct'] class VectorTransformer(object): @@ -68,6 +70,8 @@ class Normalizer(VectorTransformer): For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + :param p: Normalization in L^p^ space, p = 2 by default. + >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) >>> nor.transform(v) @@ -82,9 +86,6 @@ class Normalizer(VectorTransformer): DenseVector([0.0, 0.5, 1.0]) """ def __init__(self, p=2.0): - """ - :param p: Normalization in L^p^ space, p = 2 by default. - """ assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) @@ -94,7 +95,7 @@ def transform(self, vector): :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it - will return the input vector. + will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" @@ -111,6 +112,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): """ def transform(self, vector): + """ + Applies transformation on a vector or an RDD[Vector]. + + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param vector: Vector or RDD of Vector to be transformed. + """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -164,6 +174,13 @@ class StandardScaler(object): variance using column summary statistics on the samples in the training set. + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] >>> dataset = sc.parallelize(vs) >>> standardizer = StandardScaler(True, True) @@ -174,14 +191,6 @@ class StandardScaler(object): DenseVector([0.7071, -0.7071, 0.7071]) """ def __init__(self, withMean=False, withStd=True): - """ - :param withMean: False by default. Centers the data with mean - before scaling. It will build a dense output, so this - does not work on sparse input and will raise an - exception. - :param withStd: True by default. Scales the data to unit - standard deviation. - """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") self.withMean = withMean @@ -192,8 +201,8 @@ def fit(self, dataset): Computes the mean and variance and stores as a model to be used for later scaling. - :param data: The data used to compute the mean and variance - to build the transformation model. + :param dataset: The data used to compute the mean and variance + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -223,6 +232,8 @@ class ChiSqSelector(object): Creates a ChiSquared feature selector. + :param numTopFeatures: number of features that selector will select. + >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), @@ -236,9 +247,6 @@ class ChiSqSelector(object): DenseVector([5.0]) """ def __init__(self, numTopFeatures): - """ - :param numTopFeatures: number of features that selector will select. - """ self.numTopFeatures = int(numTopFeatures) def fit(self, data): @@ -246,14 +254,49 @@ def fit(self, data): Returns a ChiSquared feature selector. :param data: an `RDD[LabeledPoint]` containing the labeled dataset - with categorical features. Real-valued features will be - treated as categorical for each distinct value. - Apply feature discretizer before using this function. + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + Apply feature discretizer before using this function. """ jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data) return ChiSqSelectorModel(jmodel) +class PCAModel(JavaVectorTransformer): + """ + Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + """ + + +class PCA(object): + """ + A feature transformer that projects vectors to a low-dimensional space using PCA. + + >>> data = [Vectors.sparse(5, [(1, 1.0), (3, 7.0)]), + ... Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]), + ... Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])] + >>> model = PCA(2).fit(sc.parallelize(data)) + >>> pcArray = model.transform(Vectors.sparse(5, [(1, 1.0), (3, 7.0)])).toArray() + >>> pcArray[0] + 1.648... + >>> pcArray[1] + -4.013... + """ + def __init__(self, k): + """ + :param k: number of principal components. + """ + self.k = int(k) + + def fit(self, data): + """ + Computes a [[PCAModel]] that contains the principal components of the input vectors. + :param data: source vectors + """ + jmodel = callMLlibFunc("fitPCA", self.k, data) + return PCAModel(jmodel) + + class HashingTF(object): """ .. note:: Experimental @@ -263,15 +306,14 @@ class HashingTF(object): Note: the terms must be hashable (can not be dict/set/list...). + :param numFeatures: number of features (default: 2^20) + >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): - """ - :param numFeatures: number of features (default: 2^20) - """ self.numFeatures = numFeatures def indexOf(self, term): @@ -311,13 +353,9 @@ def transform(self, x): Call transform directly on the RDD instead. :param x: an RDD of term frequency vectors or a term frequency - vector + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ - if isinstance(x, RDD): - return JavaVectorTransformer.transform(self, x) - - x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) def idf(self): @@ -342,6 +380,9 @@ class IDF(object): `minDocFreq`). For terms that are not in at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. + :param minDocFreq: minimum of documents in which a term + should appear for filtering + >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), ... Vectors.dense([0.0, 1.0, 2.0, 3.0]), @@ -362,10 +403,6 @@ class IDF(object): SparseVector(4, {1: 0.0, 3: 0.5754}) """ def __init__(self, minDocFreq=0): - """ - :param minDocFreq: minimum of documents in which a term - should appear for filtering - """ self.minDocFreq = minDocFreq def fit(self, dataset): @@ -380,7 +417,7 @@ def fit(self, dataset): return IDFModel(jmodel) -class Word2VecModel(JavaVectorTransformer): +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model """ @@ -419,6 +456,12 @@ def getVectors(self): """ return self.call("getVectors") + @classmethod + def load(cls, sc, path): + jmodel = sc._jvm.org.apache.spark.mllib.feature \ + .Word2VecModel.load(sc._jsc.sc(), path) + return Word2VecModel(jmodel) + @ignore_unicode_prefix class Word2Vec(object): @@ -452,6 +495,18 @@ class Word2Vec(object): >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] [u'b', u'c'] + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = Word2VecModel.load(sc, path) + >>> model.transform("a") == sameModel.transform("a") + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def __init__(self): """ @@ -518,13 +573,45 @@ def fit(self, data): """ if not isinstance(data, RDD): raise TypeError("data should be an RDD of list of string") - jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), int(self.numIterations), int(self.seed), int(self.minCount)) return Word2VecModel(jmodel) +class ElementwiseProduct(VectorTransformer): + """ + .. note:: Experimental + + Scales each column of the vector, with the supplied weight vector. + i.e the elementwise product. + + >>> weight = Vectors.dense([1.0, 2.0, 3.0]) + >>> eprod = ElementwiseProduct(weight) + >>> a = Vectors.dense([2.0, 1.0, 3.0]) + >>> eprod.transform(a) + DenseVector([2.0, 2.0, 9.0]) + >>> b = Vectors.dense([9.0, 3.0, 4.0]) + >>> rdd = sc.parallelize([a, b]) + >>> eprod.transform(rdd).collect() + [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + """ + def __init__(self, scalingVector): + self.scalingVector = _convert_to_vector(scalingVector) + + def transform(self, vector): + """ + Computes the Hadamard product of the vector. + """ + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + + else: + vector = _convert_to_vector(vector) + return callMLlibFunc("elementwiseProductVector", self.scalingVector, vector) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index d8df02bdbaba..bdc4a132b1b1 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -61,12 +61,12 @@ class FPGrowth(object): def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. - :param data: The input data set, each element - contains a transaction. - :param minSupport: The minimal support level - (default: `0.3`). - :param numPartitions: The number of partitions used by parallel - FP-growth (default: same as input data). + + :param data: The input data set, each element contains a + transaction. + :param minSupport: The minimal support level (default: `0.3`). + :param numPartitions: The number of partitions used by + parallel FP-growth (default: same as input data). """ model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 23d1a79ffe51..9959a01cce7e 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -36,7 +36,7 @@ import numpy as np from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ - IntegerType, ByteType + IntegerType, ByteType, BooleanType __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', @@ -163,6 +163,59 @@ def simpleString(self): return "vector" +class MatrixUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Matrix. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("numRows", IntegerType(), False), + StructField("numCols", IntegerType(), False), + StructField("colPtrs", ArrayType(IntegerType(), False), True), + StructField("rowIndices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True), + StructField("isTransposed", BooleanType(), False)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.MatrixUDT" + + def serialize(self, obj): + if isinstance(obj, SparseMatrix): + colPtrs = [int(i) for i in obj.colPtrs] + rowIndices = [int(i) for i in obj.rowIndices] + values = [float(v) for v in obj.values] + return (0, obj.numRows, obj.numCols, colPtrs, + rowIndices, values, bool(obj.isTransposed)) + elif isinstance(obj, DenseMatrix): + values = [float(v) for v in obj.values] + return (1, obj.numRows, obj.numCols, None, None, values, + bool(obj.isTransposed)) + else: + raise TypeError("cannot serialize type %r" % (type(obj))) + + def deserialize(self, datum): + assert len(datum) == 7, \ + "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseMatrix(*datum[1:]) + elif tpe == 1: + return DenseMatrix(datum[1], datum[2], datum[5], datum[6]) + else: + raise ValueError("do not recognize type %r" % tpe) + + def simpleString(self): + return "matrix" + + class Vector(object): __UDT__ = VectorUDT() @@ -524,22 +577,19 @@ def dot(self, other): ... AssertionError: dimension mismatch """ - if type(other) == np.ndarray: - if other.ndim == 2: - results = [self.dot(other[:, i]) for i in xrange(other.shape[1])] - return np.array(results) - elif other.ndim > 2: + + if isinstance(other, np.ndarray): + if other.ndim not in [2, 1]: raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim) + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.values, other[self.indices]) assert len(self) == _vector_size(other), "dimension mismatch" - if type(other) in (np.ndarray, array.array, DenseVector): - result = 0.0 - for i in xrange(len(self.indices)): - result += self.values[i] * other[self.indices[i]] - return result + if isinstance(other, DenseVector): + return np.dot(other.array[self.indices], self.values) - elif type(other) is SparseVector: + elif isinstance(other, SparseVector): result = 0.0 i, j = 0, 0 while i < len(self.indices) and j < len(other.indices): @@ -582,22 +632,23 @@ def squared_distance(self, other): AssertionError: dimension mismatch """ assert len(self) == _vector_size(other), "dimension mismatch" - if type(other) in (list, array.array, DenseVector, np.array, np.ndarray): - if type(other) is np.array and other.ndim != 1: + + if isinstance(other, np.ndarray) or isinstance(other, DenseVector): + if isinstance(other, np.ndarray) and other.ndim != 1: raise Exception("Cannot call squared_distance with %d-dimensional array" % other.ndim) - result = 0.0 - j = 0 # index into our own array - for i in xrange(len(other)): - if j < len(self.indices) and self.indices[j] == i: - diff = self.values[j] - other[i] - result += diff * diff - j += 1 - else: - result += other[i] * other[i] + if isinstance(other, DenseVector): + other = other.array + sparse_ind = np.zeros(other.size, dtype=bool) + sparse_ind[self.indices] = True + dist = other[sparse_ind] - self.values + result = np.dot(dist, dist) + + other_ind = other[~sparse_ind] + result += np.dot(other_ind, other_ind) return result - elif type(other) is SparseVector: + elif isinstance(other, SparseVector): result = 0.0 i, j = 0, 0 while i < len(self.indices) and j < len(other.indices): @@ -781,10 +832,12 @@ def zeros(size): class Matrix(object): + + __UDT__ = MatrixUDT() + """ Represents a local matrix. """ - def __init__(self, numRows, numCols, isTransposed=False): self.numRows = numRows self.numCols = numCols diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py similarity index 100% rename from python/pyspark/mllib/rand.py rename to python/pyspark/mllib/random.py diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 9c4647ddfdcf..506ca2151cce 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -106,8 +106,9 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): 0.4... >>> sameModel.predictAll(testset).collect() [Rating(... + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 41bde2ce3e60..8e90adee5f4c 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -19,6 +19,7 @@ from numpy import array from pyspark import RDD +from pyspark.streaming.dstream import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector from pyspark.mllib.util import Saveable, Loader @@ -33,12 +34,12 @@ class LabeledPoint(object): """ - The features and labels of a data point. + Class that represents the features and labels of a data point. :param label: Label for this data point. :param features: Vector of features for this point (NumPy array, - list, pyspark.mllib.linalg.SparseVector, or scipy.sparse - column matrix) + list, pyspark.mllib.linalg.SparseVector, or scipy.sparse + column matrix) Note: 'label' and 'features' are accessible as class attributes. """ @@ -59,7 +60,12 @@ def __repr__(self): class LinearModel(object): - """A linear model that has a vector of coefficients and an intercept.""" + """ + A linear model that has a vector of coefficients and an intercept. + + :param weights: Weights computed for every feature. + :param intercept: Intercept computed for this model. + """ def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) @@ -128,10 +134,11 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: - ... pass + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -193,18 +200,28 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, validateData=True): """ - Train a linear regression model on the given data. - - :param data: The training data. - :param iterations: The number of iterations (default: 100). + Train a linear regression model using Stochastic Gradient + Descent (SGD). + This solves the least squares regression formulation + f(weights) = 1/n ||A weights-y||^2^ + (which is the mean squared error). + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each SGD - iteration. + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 0.0). - :param regType: The type of regularizer used for training - our model. + :param regParam: The regularizer parameter + (default: 0.0). + :param regType: The type of regularizer used for + training our model. :Allowed values: - "l1" for using L1 regularization (lasso), @@ -213,13 +230,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: None) - :param intercept: Boolean parameter which indicates the use - or not of the augmented representation for - training data (i.e. whether bias features - are activated or not). (default: False) - :param validateData: Boolean parameter which indicates if the - algorithm should validate data before training. - (default: True) + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -232,8 +250,8 @@ def train(rdd, i): @inherit_doc class LassoModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_1 penalty term. + """A linear regression model derived from a least-squares fit with + an l_1 penalty term. >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -259,8 +277,9 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> data = [ @@ -304,7 +323,36 @@ class LassoWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True): - """Train a Lasso regression model on the given data.""" + """ + Train a regression model with L1-regularization using + Stochastic Gradient Descent. + This solves the l1-regularized least squares regression + formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). + :param step: The step parameter used in SGD + (default: 1.0). + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). + :param initialWeights: The initial weights (default: None). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) + """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), @@ -316,8 +364,8 @@ def train(rdd, i): @inherit_doc class RidgeRegressionModel(LinearRegressionModelBase): - """A linear regression model derived from a least-squares fit with an - l_2 penalty term. + """A linear regression model derived from a least-squares fit with + an l_2 penalty term. >>> from pyspark.mllib.regression import LabeledPoint >>> data = [ @@ -344,8 +392,9 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> data = [ @@ -389,7 +438,36 @@ class RidgeRegressionWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True): - """Train a ridge regression model on the given data.""" + """ + Train a regression model with L2-regularization using + Stochastic Gradient Descent. + This solves the l2-regularized least squares regression + formulation + f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + Here the data matrix has n rows, and the input RDD holds the + set of rows of A, each with its corresponding right hand side + label y. See also the documentation for the precise formulation. + + :param data: The training data, an RDD of + LabeledPoint. + :param iterations: The number of iterations + (default: 100). + :param step: The step parameter used in SGD + (default: 1.0). + :param regParam: The regularizer parameter + (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each + SGD iteration (default: 1.0). + :param initialWeights: The initial weights (default: None). + :param intercept: Boolean parameter which indicates the + use or not of the augmented representation + for training data (i.e. whether bias + features are activated or not, + default: False). + :param validateData: Boolean parameter which indicates if + the algorithm should validate data + before training. (default: True) + """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), @@ -400,7 +478,15 @@ def train(rdd, i): class IsotonicRegressionModel(Saveable, Loader): - """Regression model for isotonic regression. + """ + Regression model for isotonic regression. + + :param boundaries: Array of boundaries for which predictions are + known. Boundaries must be sorted in increasing order. + :param predictions: Array of predictions associated to the + boundaries at the same index. Results of isotonic + regression and therefore monotone. + :param isotonic: indicates whether this is isotonic or antitonic. >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) @@ -418,8 +504,9 @@ class IsotonicRegressionModel(Saveable, Loader): 2.0 >>> sameModel.predict(5) 16.5 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ @@ -430,6 +517,25 @@ def __init__(self, boundaries, predictions, isotonic): self.isotonic = isotonic def predict(self, x): + """ + Predict labels for provided features. + Using a piecewise linear function. + 1) If x exactly matches a boundary then associated prediction + is returned. In case there are multiple predictions with the + same boundary then one of them is returned. Which one is + undefined (same as java.util.Arrays.binarySearch). + 2) If x is lower or higher than all boundaries then first or + last prediction is returned respectively. In case there are + multiple predictions with the same boundary then the lowest + or highest is returned respectively. + 3) If x falls between two values in boundary array then + prediction is treated as piecewise linear function and + interpolated value is returned. In case there are multiple + values with the same boundary then the same rules as in 2) + are used. + + :param x: Feature or RDD of Features to be labeled. + """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) return np.interp(x, self.boundaries, self.predictions) @@ -451,20 +557,109 @@ def load(cls, sc, path): class IsotonicRegression(object): - """ - Run IsotonicRegression algorithm to obtain isotonic regression model. - :param data: RDD of (label, feature, weight) tuples. - :param isotonic: Whether this is isotonic or antitonic. - """ @classmethod def train(cls, data, isotonic=True): - """Train a isotonic regression model on the given data.""" + """ + Train a isotonic regression model on the given data. + + :param data: RDD of (label, feature, weight) tuples. + :param isotonic: Whether this is isotonic or antitonic. + """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic) +class StreamingLinearAlgorithm(object): + """ + Base class that has to be inherited by any StreamingLinearAlgorithm. + + Prevents reimplementation of methods predictOn and predictOnValues. + """ + def __init__(self, model): + self._model = model + + def latestModel(self): + """ + Returns the latest model. + """ + return self._model + + def _validate(self, dstream): + if not isinstance(dstream, DStream): + raise TypeError( + "dstream should be a DStream object, got %s" % type(dstream)) + if not self._model: + raise ValueError( + "Model must be intialized using setInitialWeights") + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +@inherit_doc +class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LinearRegression with SGD on a batch of data. + + The problem minimized is (1 / n_samples) * (y - weights'X)**2. + After training on a batch of data, the weights obtained at the end of + training are used as initial weights for the next batch. + + :param: stepSize Step size for each iteration of gradient descent. + :param: numIterations Total number of iterations run. + :param: miniBatchFraction Fraction of data on which SGD is run for each + iteration. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + self.stepSize = stepSize + self.numIterations = numIterations + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLinearRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn + """ + initialWeights = _convert_to_vector(initialWeights) + self._model = LinearRegressionModel(initialWeights, 0) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LinearRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LinearRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights, + self._model.intercept) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py new file mode 100644 index 000000000000..7da921976d4d --- /dev/null +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +if sys.version > '3': + xrange = range + +import numpy as np + +from pyspark.mllib.common import callMLlibFunc +from pyspark.rdd import RDD + + +class KernelDensity(object): + """ + .. note:: Experimental + + Estimate probability density at required points given a RDD of samples + from the population. + + >>> kd = KernelDensity() + >>> sample = sc.parallelize([0.0, 1.0]) + >>> kd.setSample(sample) + >>> kd.estimate([0.0, 1.0]) + array([ 0.12938758, 0.12938758]) + """ + def __init__(self): + self._bandwidth = 1.0 + self._sample = None + + def setBandwidth(self, bandwidth): + """Set bandwidth of each sample. Defaults to 1.0""" + self._bandwidth = bandwidth + + def setSample(self, sample): + """Set sample points from the population. Should be a RDD""" + if not isinstance(sample, RDD): + raise TypeError("samples should be a RDD, received %s" % type(sample)) + self._sample = sample + + def estimate(self, points): + """Estimate the probability density at points""" + points = list(points) + densities = callMLlibFunc( + "estimateKernelDensity", self._sample, self._bandwidth, points) + return np.asarray(densities) diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py index e3e128513e0d..c8a721d3fe41 100644 --- a/python/pyspark/mllib/stat/__init__.py +++ b/python/pyspark/mllib/stat/__init__.py @@ -22,6 +22,7 @@ from pyspark.mllib.stat._statistics import * from pyspark.mllib.stat.distribution import MultivariateGaussian from pyspark.mllib.stat.test import ChiSqTestResult +from pyspark.mllib.stat.KernelDensity import KernelDensity __all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult", - "MultivariateGaussian"] + "MultivariateGaussian", "KernelDensity"] diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 36a4c7a5408c..d9f9874d50c1 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -23,8 +23,13 @@ import sys import tempfile import array as pyarray +from time import time, sleep +from shutil import rmtree + +from numpy import ( + array, array_equal, zeros, inf, random, exp, dot, all, mean, abs) +from numpy import sum as array_sum -from numpy import array, array_equal, zeros, inf from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -38,16 +43,22 @@ from pyspark import SparkContext from pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ - DenseMatrix, SparseMatrix, Vectors, Matrices -from pyspark.mllib.regression import LabeledPoint + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF -from pyspark.mllib.feature import StandardScaler +from pyspark.mllib.feature import StandardScaler, ElementwiseProduct +from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer +from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext +from pyspark.streaming import StreamingContext _have_scipy = False try: @@ -66,6 +77,20 @@ def setUp(self): self.sc = sc +class MLLibStreamingTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc + self.ssc = StreamingContext(self.sc, 1.0) + + def tearDown(self): + self.ssc.stop(False) + + @staticmethod + def _ssc_wait(start_time, end_time, sleep_time): + while time() - start_time < end_time: + sleep(0.01) + + def _squared_distance(a, b): if isinstance(a, Vector): return a.squared_distance(b) @@ -104,17 +129,22 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) self.assertEquals(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) self.assertEquals(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) self.assertEquals(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEquals(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) dv = DenseVector(array([1., 2., 3., 4.])) lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) self.assertEquals(15.0, _squared_distance(sv, dv)) self.assertEquals(25.0, _squared_distance(sv, lst)) self.assertEquals(20.0, _squared_distance(dv, lst)) @@ -124,6 +154,9 @@ def test_squared_distance(self): self.assertEquals(0.0, _squared_distance(sv, sv)) self.assertEquals(0.0, _squared_distance(dv, dv)) self.assertEquals(0.0, _squared_distance(lst, lst)) + self.assertEquals(25.0, _squared_distance(sv, lst1)) + self.assertEquals(3.0, _squared_distance(sv, arr)) + self.assertEquals(3.0, _squared_distance(sv, narr)) def test_conversion(self): # numpy arrays should be automatically upcast to float64 @@ -379,7 +412,7 @@ def test_classification(self): self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) try: - os.removedirs(temp_dir) + rmtree(temp_dir) except OSError: pass @@ -443,6 +476,13 @@ def test_regression(self): except ValueError: self.fail() + # Verify that maxBins is being passed through + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32) + with self.assertRaises(Exception) as cm: + GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1) + class StatTests(MLlibTestCase): # SPARK-4023 @@ -507,6 +547,38 @@ def test_infer_schema(self): raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(MLlibTestCase): @@ -818,6 +890,457 @@ def test_model_transform(self): self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) +class ElementwiseProductTests(MLlibTestCase): + def test_model_transform(self): + weight = Vectors.dense([3, 2, 1]) + + densevec = Vectors.dense([4, 5, 6]) + sparsevec = Vectors.sparse(3, [0], [1]) + eprod = ElementwiseProduct(weight) + self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6])) + self.assertEqual( + eprod.transform(sparsevec), SparseVector(3, [0], [3])) + + +class StreamingKMeansTest(MLLibStreamingTestCase): + def test_model_params(self): + """Test that the model params are set correctly""" + stkm = StreamingKMeans() + stkm.setK(5).setDecayFactor(0.0) + self.assertEquals(stkm._k, 5) + self.assertEquals(stkm._decayFactor, 0.0) + + # Model not set yet. + self.assertIsNone(stkm.latestModel()) + self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) + + stkm.setInitialCenters( + centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) + self.assertEquals( + stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) + self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + + def test_accuracy_for_single_center(self): + """Test that parameters obtained are correct for a single center.""" + centers, batches = self.streamingKMeansDataGenerator( + batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) + stkm = StreamingKMeans(1) + stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) + input_stream = self.ssc.queueStream( + [self.sc.parallelize(batch, 1) for batch in batches]) + stkm.trainOn(input_stream) + + t = time() + self.ssc.start() + self._ssc_wait(t, 10.0, 0.01) + self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + realCenters = array_sum(array(centers), axis=0) + for i in range(5): + modelCenters = stkm.latestModel().centers[0][i] + self.assertAlmostEqual(centers[0][i], modelCenters, 1) + self.assertAlmostEqual(realCenters[i], modelCenters, 1) + + def streamingKMeansDataGenerator(self, batches, numPoints, + k, d, r, seed, centers=None): + rng = random.RandomState(seed) + + # Generate centers. + centers = [rng.randn(d) for i in range(k)] + + return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) + for j in range(numPoints)] + for i in range(batches)] + + def test_trainOn_model(self): + """Test the model on toy data with four clusters.""" + stkm = StreamingKMeans() + initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] + stkm.setInitialCenters( + centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) + + # Create a toy dataset by setting a tiny offest for each point. + offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] + batches = [] + for offset in offsets: + batches.append([[offset[0] + center[0], offset[1] + center[1]] + for center in initCenters]) + + batches = [self.sc.parallelize(batch, 1) for batch in batches] + input_stream = self.ssc.queueStream(batches) + stkm.trainOn(input_stream) + t = time() + self.ssc.start() + + # Give enough time to train the model. + self._ssc_wait(t, 6.0, 0.01) + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + + def test_predictOn_model(self): + """Test that the model predicts correctly on toy data.""" + stkm = StreamingKMeans() + stkm._model = StreamingKMeansModel( + clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], + clusterWeights=[1.0, 1.0, 1.0, 1.0]) + + predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] + predict_data = [sc.parallelize(batch, 1) for batch in predict_data] + predict_stream = self.ssc.queueStream(predict_data) + predict_val = stkm.predictOn(predict_stream) + + result = [] + + def update(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + result.append(rdd_collect) + + predict_val.foreachRDD(update) + t = time() + self.ssc.start() + self._ssc_wait(t, 6.0, 0.01) + self.assertEquals(result, [[0], [1], [2], [3]]) + + def test_trainOn_predictOn(self): + """Test that prediction happens on the updated model.""" + stkm = StreamingKMeans(decayFactor=0.0, k=2) + stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) + + # Since decay factor is set to zero, once the first batch + # is passed the clusterCenters are updated to [-0.5, 0.7] + # which causes 0.2 & 0.3 to be classified as 1, even though the + # classification based in the initial model would have been 0 + # proving that the model is updated. + batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] + batches = [sc.parallelize(batch) for batch in batches] + input_stream = self.ssc.queueStream(batches) + predict_results = [] + + def collect(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + predict_results.append(rdd_collect) + + stkm.trainOn(input_stream) + predict_stream = stkm.predictOn(input_stream) + predict_stream.foreachRDD(collect) + + t = time() + self.ssc.start() + self._ssc_wait(t, 6.0, 0.01) + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + + +class LinearDataGeneratorTests(MLlibTestCase): + def test_dim(self): + linear_data = LinearDataGenerator.generateLinearInput( + intercept=0.0, weights=[0.0, 0.0, 0.0], + xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33], + nPoints=4, seed=0, eps=0.1) + self.assertEqual(len(linear_data), 4) + for point in linear_data: + self.assertEqual(len(point.features), 3) + + linear_data = LinearDataGenerator.generateLinearRDD( + sc=sc, nexamples=6, nfeatures=2, eps=0.1, + nParts=2, intercept=0.0).collect() + self.assertEqual(len(linear_data), 6) + for point in linear_data: + self.assertEqual(len(point.features), 2) + + +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + t = time() + self.ssc.start() + self._ssc_wait(t, 15.0, 0.01) + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + + # Test that weights improve with a small tolerance, + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + t = time() + self.ssc.start() + self._ssc_wait(t, 5.0, 0.01) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + + # Test that the improvement in error is atleast 0.3 + self.assertTrue(errors[1] - errors[-1] > 0.3) + + +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): + + def assertArrayAlmostEqual(self, array1, array2, dec): + for i, j in array1, array2: + self.assertAlmostEqual(i, j, dec) + + def test_parameter_accuracy(self): + """Test that coefs are predicted accurately by fitting on toy data.""" + + # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients + # (10, 10) + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0, 0.0]) + xMean = [0.0, 0.0] + xVariance = [1.0 / 3.0, 1.0 / 3.0] + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + input_stream = self.ssc.queueStream(batches) + t = time() + slr.trainOn(input_stream) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + + def test_parameter_convergence(self): + """Test that the model parameters improve with streaming data.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + input_stream.foreachRDD( + lambda x: model_weights.append(slr.latestModel().weights[0])) + t = time() + slr.trainOn(input_stream) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + + model_weights = array(model_weights) + diff = model_weights[1:] - model_weights[:-1] + self.assertTrue(all(diff >= -0.1)) + + def test_prediction(self): + """Test prediction on a model with weights already set.""" + # Create a model with initial Weights equal to coefs + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([10.0, 10.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], + 100, 42 + i, 0.1) + batches.append( + sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + + input_stream = self.ssc.queueStream(batches) + t = time() + output_stream = slr.predictOnValues(input_stream) + samples = [] + output_stream.foreachRDD(lambda x: samples.append(x.collect())) + + self.ssc.start() + self._ssc_wait(t, 5, 0.01) + + # Test that mean absolute error on each batch is less than 0.1 + for batch in samples: + true, predicted = zip(*batch) + self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) + + def test_train_prediction(self): + """Test that error on test data improves as model is trained.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in batches] + mean_absolute_errors = [] + + def func(rdd): + true, predicted = zip(*rdd.collect()) + mean_absolute_errors.append(mean(abs(true) - abs(predicted))) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + output_stream = self.ssc.queueStream(predict_batches) + t = time() + slr.trainOn(input_stream) + output_stream = slr.predictOnValues(output_stream) + output_stream.foreachRDD(func) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) + + +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index cfcbea573fd2..372b86a7c95d 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -299,7 +299,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting features - (default: 100) + (default: 32) :param seed: Random seed for bootstrapping and choosing feature subsets. :return: RandomForestModel that can be used for prediction @@ -377,7 +377,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 4) :param maxBins: maximum number of bins used for splitting - features (default: 100) + features (default: 32) :param seed: Random seed for bootstrapping and choosing feature subsets. :return: RandomForestModel that can be used for prediction @@ -435,16 +435,17 @@ class GradientBoostedTrees(object): @classmethod def _train(cls, data, algo, categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth): + loss, numIterations, learningRate, maxDepth, maxBins): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) return GradientBoostedTreesModel(model) @classmethod def trainClassifier(cls, data, categoricalFeaturesInfo, - loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): + loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, + maxBins=32): """ Method to train a gradient-boosted trees model for classification. @@ -467,6 +468,8 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 3) + :param maxBins: maximum number of bins used for splitting + features (default: 32) DecisionTree requires maxBins >= max categories :return: GradientBoostedTreesModel that can be used for prediction @@ -499,11 +502,12 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, [1.0, 0.0] """ return cls._train(data, "classification", categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) @classmethod def trainRegressor(cls, data, categoricalFeaturesInfo, - loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3): + loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, + maxBins=32): """ Method to train a gradient-boosted trees model for regression. @@ -522,6 +526,8 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, contribution of each estimator. The learning rate should be between in the interval (0, 1]. (default: 0.1) + :param maxBins: maximum number of bins used for splitting + features (default: 32) DecisionTree requires maxBins >= max categories :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default: 3) @@ -556,7 +562,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, [1.0, 0.0] """ return cls._train(data, "regression", categoricalFeaturesInfo, - loss, numIterations, learningRate, maxDepth) + loss, numIterations, learningRate, maxDepth, maxBins) def _test(): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 16a90db146ef..875d3b2d642c 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None): minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) + @staticmethod + def appendBias(data): + """ + Returns a new vector with `1.0` (bias) appended to + the end of the input vector. + """ + vec = _convert_to_vector(data) + if isinstance(vec, SparseVector): + newIndices = np.append(vec.indices, len(vec)) + newValues = np.append(vec.values, 1.0) + return SparseVector(len(vec) + 1, newIndices, newValues) + else: + return _convert_to_vector(np.append(vec.toArray(), 1.0)) + + @staticmethod + def loadVectors(sc, path): + """ + Loads vectors saved using `RDD[Vector].saveAsTextFile` + with the default number of partitions. + """ + return callMLlibFunc("loadVectors", sc, path) + class Saveable(object): """ @@ -257,6 +279,41 @@ def load(cls, sc, path): return cls(java_model) +class LinearDataGenerator(object): + """Utils for generating linear data""" + + @staticmethod + def generateLinearInput(intercept, weights, xMean, xVariance, + nPoints, seed, eps): + """ + :param: intercept bias factor, the term c in X'w + c + :param: weights feature vector, the term w in X'w + c + :param: xMean Point around which the data X is centered. + :param: xVariance Variance of the given data + :param: nPoints Number of points to be generated + :param: seed Random Seed + :param: eps Used to scale the noise. If eps is set high, + the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints + """ + weights = [float(weight) for weight in weights] + xMean = [float(mean) for mean in xMean] + xVariance = [float(var) for var in xVariance] + return list(callMLlibFunc( + "generateLinearInputWrapper", float(intercept), weights, xMean, + xVariance, int(nPoints), int(seed), float(eps))) + + @staticmethod + def generateLinearRDD(sc, nexamples, nfeatures, eps, + nParts=2, intercept=0.0): + """ + Generate a RDD of LabeledPoints. + """ + return callMLlibFunc( + "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures), + float(eps), int(nParts), float(intercept)) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index d18daaabfcb3..44d17bd62947 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -90,9 +90,11 @@ class Profiler(object): >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.parallelize(range(1000)).count() + 1000 >>> sc.show_profiles() My custom profiles for RDD:1 - My custom profiles for RDD:2 + My custom profiles for RDD:3 >>> sc.stop() """ @@ -169,4 +171,6 @@ def stats(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 545c5ad20cb9..79dafb0a4ef2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -121,10 +121,23 @@ def _parse_memory(s): def _load_from_socket(port, serializer): - sock = socket.socket() - sock.settimeout(3) + sock = None + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. + for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + try: + sock.settimeout(3) + sock.connect(sa) + except socket.error: + sock.close() + sock = None + continue + break + if not sock: + raise Exception("could not open socket") try: - sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) for item in serializer.load_stream(rf): yield item @@ -813,13 +826,21 @@ def op(x, y): def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all - the partitions, using a given associative function and a neutral "zero - value." + the partitions, using a given associative and commutative function and + a neutral "zero value." The function C{op(t1, t2)} is allowed to modify C{t1} and return it as its result value to avoid object allocation; however, it should not modify C{t2}. + This behaves somewhat differently from fold operations implemented + for non-distributed collections in functional languages like Scala. + This fold operation may be applied to partitions individually, and then + fold those results into the final result, rather than apply the fold + to each element sequentially in some defined ordering. For functions + that are not commutative, the result may differ from that of a fold + applied to a non-distributed collection. + >>> from operator import add >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 @@ -952,7 +973,7 @@ def sum(self): >>> sc.parallelize([1.0, 2.0, 3.0]).sum() 6.0 """ - return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add) def count(self): """ @@ -2190,7 +2211,7 @@ def sumApprox(self, timeout, confidence=0.95): >>> rdd = sc.parallelize(range(1000), 10) >>> r = sum(range(1000)) - >>> (rdd.sumApprox(1000) - r) / r < 0.05 + >>> abs(rdd.sumApprox(1000) - r) / r < 0.05 True """ jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd() @@ -2207,7 +2228,7 @@ def meanApprox(self, timeout, confidence=0.95): >>> rdd = sc.parallelize(range(1000), 10) >>> r = sum(range(1000)) / 1000.0 - >>> (rdd.meanApprox(1000) - r) / r < 0.05 + >>> abs(rdd.meanApprox(1000) - r) / r < 0.05 True """ jrdd = self.map(float)._to_java_object_rdd() @@ -2260,7 +2281,7 @@ def toLocalIterator(self): def _prepare_for_python_RDD(sc, command, obj=None): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() - pickled_command = ser.dumps((command, sys.version_info[:2])) + pickled_command = ser.dumps(command) if len(pickled_command) > (1 << 20): # 1M # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) @@ -2344,7 +2365,7 @@ def _jrdd(self): python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), bytearray(pickled_cmd), env, includes, self.preservesPartitioning, - self.ctx.pythonExec, + self.ctx.pythonExec, self.ctx.pythonVer, bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d8cdcda3a378..411b4dbf481f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -44,8 +44,8 @@ >>> rdd.glom().collect() [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -8L +>>> int(rdd._jrdd.count()) +8 >>> sc.stop() """ @@ -272,7 +272,7 @@ def dump_stream(self, iterator, stream): if size < best: batch *= 2 elif size > best * 10 and batch > 1: - batch /= 2 + batch //= 2 def __repr__(self): return "AutoBatchedSerializer(%s)" % self.serializer @@ -556,4 +556,6 @@ def write_with_length(obj, stream): if __name__ == '__main__': import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 1d0b16cade8b..8fb71bac64a5 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -362,7 +362,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def items(self): """ Return all merged items as iterator """ @@ -486,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch, limit = 100, self.memory_limit + batch, limit = 100, self._next_limit() chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -512,10 +512,7 @@ def load(f): f.close() chunks.append(load(open(path, 'rb'))) current_chunk = [] - gc.collect() - batch //= 2 - limit = self._next_limit() - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close @@ -630,7 +627,7 @@ def _spill(self): self.values = [] gc.collect() DiskBytesSpilled += self._file.tell() - pos - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 class ExternalListOfList(ExternalList): @@ -794,7 +791,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def _merged_items(self, index): size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) @@ -841,4 +838,6 @@ def load_partition(j): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 7192c89b3dc7..ad9c891ba1c0 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -18,47 +18,58 @@ """ Important classes of Spark SQL and DataFrames: - - L{SQLContext} + - :class:`pyspark.sql.SQLContext` Main entry point for :class:`DataFrame` and SQL functionality. - - L{DataFrame} + - :class:`pyspark.sql.DataFrame` A distributed collection of data grouped into named columns. - - L{Column} + - :class:`pyspark.sql.Column` A column expression in a :class:`DataFrame`. - - L{Row} + - :class:`pyspark.sql.Row` A row of data in a :class:`DataFrame`. - - L{HiveContext} + - :class:`pyspark.sql.HiveContext` Main entry point for accessing data stored in Apache Hive. - - L{GroupedData} + - :class:`pyspark.sql.GroupedData` Aggregation methods, returned by :func:`DataFrame.groupBy`. - - L{DataFrameNaFunctions} + - :class:`pyspark.sql.DataFrameNaFunctions` Methods for handling missing data (null values). - - L{DataFrameStatFunctions} + - :class:`pyspark.sql.DataFrameStatFunctions` Methods for statistics functionality. - - L{functions} + - :class:`pyspark.sql.functions` List of built-in functions available for :class:`DataFrame`. - - L{types} + - :class:`pyspark.sql.types` List of data types available. + - :class:`pyspark.sql.Window` + For working with window functions. """ from __future__ import absolute_import -# fix the module name conflict for Python 3+ -import sys -from . import _types as types -modname = __name__ + '.types' -types.__name__ = modname -# update the __module__ for all objects, make them picklable -for v in types.__dict__.values(): - if hasattr(v, "__module__") and v.__module__.endswith('._types'): - v.__module__ = modname -sys.modules[modname] = types -del modname, sys + +def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + + def deco(f): + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) + return f + return deco + from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions -from pyspark.sql.dataframe import DataFrameStatFunctions +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions +from pyspark.sql.group import GroupedData +from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter +from pyspark.sql.window import Window, WindowSpec + __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', - 'DataFrameNaFunctions', 'DataFrameStatFunctions' + 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py new file mode 100644 index 000000000000..0a85da7443d3 --- /dev/null +++ b/python/pyspark/sql/column.py @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +if sys.version >= '3': + basestring = str + long = int + +from pyspark.context import SparkContext +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql import since +from pyspark.sql.types import * + +__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", + "DataFrameStatFunctions"] + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.functions.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.functions.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toSeq(cols) + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + + .. note:: Experimental + + .. versionadded:: 1.3 + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("apply") + + # bitwise operators + bitwiseOR = _bin_op("bitwiseOR") + bitwiseAND = _bin_op("bitwiseAND") + bitwiseXOR = _bin_op("bitwiseXOR") + + @since(1.3) + def getItem(self, key): + """ + An expression that gets an item at position ``ordinal`` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + >>> df.select(df.l[0], df.d["key"]).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + """ + return self[key] + + @since(1.3) + def getField(self, name): + """ + An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + +----+ + |r[b]| + +----+ + | b| + +----+ + >>> df.select(df.r.a).show() + +----+ + |r[a]| + +----+ + | 1| + +----+ + """ + return self[name] + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + @ignore_unicode_prefix + @since(1.3) + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column. + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + @ignore_unicode_prefix + @since(1.3) + def inSet(self, *cols): + """ + A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + return Column(jc) + + # order + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + @since(1.3) + def alias(self, *alias): + """ + Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). + + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + + @ignore_unicode_prefix + @since(1.3) + def cast(self, dataType): + """ Convert the column into type ``dataType``. + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + astype = cast + + @since(1.3) + def between(self, lowerBound, upperBound): + """ + A boolean expression that is evaluated to true if the value of this + expression is between the given columns. + + >>> df.select(df.name, df.age.between(2, 4)).show() + +-----+--------------------------+ + | name|((age >= 2) && (age <= 4))| + +-----+--------------------------+ + |Alice| true| + | Bob| false| + +-----+--------------------------+ + """ + return (self >= lowerBound) & (self <= upperBound) + + @since(1.4) + def when(self, condition, value): + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+--------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| + +-----+--------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+--------------------------------------------------------+ + """ + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = self._jc.when(condition._jc, v) + return Column(jc) + + @since(1.4) + def otherwise(self, value): + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+---------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0| + +-----+---------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+---------------------------------+ + """ + v = value._jc if isinstance(value, Column) else value + jc = self._jc.otherwise(v) + return Column(jc) + + @since(1.4) + def over(self, window): + """ + Define a windowing column. + + :param window: a :class:`WindowSpec` + :return: a Column + + >>> from pyspark.sql import Window + >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1) + >>> from pyspark.sql.functions import rank, min + >>> # df.select(rank().over(window), min('age').over(window)) + + .. note:: Window functions is only supported with HiveContext in 1.4 + """ + from pyspark.sql.window import WindowSpec + if not isinstance(window, WindowSpec): + raise TypeError("window should be WindowSpec") + jc = self._jc.over(window._jspec) + return Column(jc) + + def __nonzero__(self): + raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " + "'~' for 'not' when building DataFrame boolean expressions.") + __bool__ = __nonzero__ + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + import pyspark.sql.column + globs = pyspark.sql.column.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.column, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index f6f107ca32d2..309c11faf931 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -28,9 +28,12 @@ from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame +from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.utils import install_exception_handler try: import pandas @@ -84,7 +87,8 @@ def __init__(self, sparkContext, sqlContext=None): >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ + time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -93,6 +97,7 @@ def __init__(self, sparkContext, sqlContext=None): self._jvm = self._sc._jvm self._scala_SQLContext = sqlContext _monkey_patch_RDD(self) + install_exception_handler() @property def _ssql_ctx(self): @@ -105,11 +110,13 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + @since(1.3) def setConf(self, key, value): """Sets the given Spark SQL configuration property. """ self._ssql_ctx.setConf(key, value) + @since(1.3) def getConf(self, key, defaultValue): """Returns the value of Spark SQL configuration property for the given key. @@ -118,11 +125,47 @@ def getConf(self, key, defaultValue): return self._ssql_ctx.getConf(key, defaultValue) @property + @since("1.3.1") def udf(self): - """Returns a :class:`UDFRegistration` for UDF registration.""" + """Returns a :class:`UDFRegistration` for UDF registration. + + :return: :class:`UDFRegistration` + """ return UDFRegistration(self) + @since(1.4) + def range(self, start, end=None, step=1, numPartitions=None): + """ + Create a :class:`DataFrame` with single LongType column named `id`, + containing elements in a range from `start` to `end` (exclusive) with + step value `step`. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numPartitions: the number of partitions of the DataFrame + :return: :class:`DataFrame` + + >>> sqlContext.range(1, 7, 2).collect() + [Row(id=1), Row(id=3), Row(id=5)] + + If only one argument is specified, it will be used as the end value. + + >>> sqlContext.range(3).collect() + [Row(id=0), Row(id=1), Row(id=2)] + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + if end is None: + jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) + else: + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + + return DataFrame(jdf, self) + @ignore_unicode_prefix + @since(1.2) def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. @@ -136,17 +179,17 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] + [Row(_c0=u'4')] >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] >>> from pyspark.sql.types import IntegerType >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] """ func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) @@ -157,18 +200,49 @@ def registerFunction(self, name, f, returnType=StringType()): env, includes, self._sc.pythonExec, + self._sc.pythonVer, bvars, self._sc._javaAccumulator, returnType.json()) + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = _infer_schema(first) + if _has_nulltype(schema): + for r in data: + schema = _merge_type(schema, _infer_schema(r)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined after inferring") + return schema + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") + warnings.warn("Using RDD of dict to inferSchema is deprecated. " + "Use pyspark.sql.Row instead") if samplingRatio is None: schema = _infer_schema(first) @@ -188,9 +262,10 @@ def _inferSchema(self, rdd, samplingRatio=None): @ignore_unicode_prefix def inferSchema(self, rdd, samplingRatio=None): - """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead") + .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -199,7 +274,8 @@ def inferSchema(self, rdd, samplingRatio=None): @ignore_unicode_prefix def applySchema(self, rdd, schema): - """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. """ warnings.warn("applySchema is deprecated, please use createDataFrame instead") @@ -211,6 +287,7 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + @since(1.3) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): """ @@ -231,6 +308,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :class:`list`, or :class:`pandas.DataFrame`. :param schema: a :class:`StructType` or list of column names. default None. :param samplingRatio: the sample ratio of rows used for inferring + :return: :class:`DataFrame` >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() @@ -266,16 +344,20 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] + >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]]).collect()) # doctest: +SKIP + [Row(0=1, 1=2)] """ if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") if has_pandas and isinstance(data, pandas.DataFrame): if schema is None: - schema = list(data.columns) + schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): + if not isinstance(data, list): + data = list(data) try: # data could be list, tuple, generator ... rdd = self._sc.parallelize(data) @@ -284,28 +366,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): else: rdd = data - if schema is None: - schema = self._inferSchema(rdd, samplingRatio) + if schema is None or isinstance(schema, (list, tuple)): + if isinstance(data, RDD): + struct = self._inferSchema(rdd, samplingRatio) + else: + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct converter = _create_converter(schema) rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - first = rdd.first() - if not isinstance(first, (list, tuple)): - raise TypeError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) - row_cls = Row(*schema) - schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) + elif isinstance(schema, StructType): + # take the first few rows to verify schema rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) - for row in rows: - _verify_type(row, schema) + else: + raise TypeError("schema should be StructType or list or None") # convert python objects to sql data converter = _python_to_sql_converter(schema) @@ -315,6 +395,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return DataFrame(df, self) + @since(1.3) def registerDataFrameAsTable(self, df, tableName): """Registers the given :class:`DataFrame` as a temporary table in the catalog. @@ -330,14 +411,12 @@ def registerDataFrameAsTable(self, df, tableName): def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. + + >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ + warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") gateway = self._sc._gateway jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) for i in range(0, len(paths)): @@ -348,35 +427,12 @@ def parquetFile(self, *paths): def jsonFile(self, path, schema=None, samplingRatio=1.0): """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.jsonFile(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.jsonFile(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes + [('age', 'bigint'), ('name', 'string')] """ + warnings.warn("jsonFile is deprecated. Use read.json() instead.") if schema is None: df = self._ssql_ctx.jsonFile(path, samplingRatio) else: @@ -385,6 +441,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): return DataFrame(df, self) @ignore_unicode_prefix + @since(1.0) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. @@ -430,28 +487,13 @@ def func(iterator): def load(self, path=None, source=None, schema=None, **options): """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Optionally, a schema can be provided as the schema of the returned DataFrame. + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. """ - if path is not None: - options["path"] = path - if source is None: - source = self.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - if schema is None: - df = self._ssql_ctx.load(source, options) - else: - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.load(source, scala_datatype, options) - return DataFrame(df, self) + warnings.warn("load is deprecated. Use read.load() instead.") + return self.read.load(path, source, schema, **options) - def createExternalTable(self, tableName, path=None, source=None, - schema=None, **options): + @since(1.3) + def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): """Creates an external table based on the dataset in a data source. It returns the DataFrame associated with the external table. @@ -462,6 +504,8 @@ def createExternalTable(self, tableName, path=None, source=None, Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. + + :return: :class:`DataFrame` """ if path is not None: options["path"] = path @@ -479,9 +523,12 @@ def createExternalTable(self, tableName, path=None, source=None, return DataFrame(df, self) @ignore_unicode_prefix + @since(1.0) def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -489,9 +536,12 @@ def sql(self, sqlQuery): """ return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + @since(1.0) def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -500,6 +550,7 @@ def table(self, tableName): return DataFrame(self._ssql_ctx.table(tableName), self) @ignore_unicode_prefix + @since(1.3) def tables(self, dbName=None): """Returns a :class:`DataFrame` containing names of tables in the given database. @@ -508,6 +559,9 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + :param dbName: string, name of the database to use. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() @@ -518,10 +572,12 @@ def tables(self, dbName=None): else: return DataFrame(self._ssql_ctx.tables(dbName), self) + @since(1.3) def tableNames(self, dbName=None): """Returns a list of names of tables in the database ``dbName``. - If ``dbName`` is not specified, the current database will be used. + :param dbName: string, name of the database to use. Default to the current database. + :return: list of table names, in string >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() @@ -534,18 +590,32 @@ def tableNames(self, dbName=None): else: return [name for name in self._ssql_ctx.tableNames(dbName)] + @since(1.0) def cacheTable(self, tableName): """Caches the specified table in-memory.""" self._ssql_ctx.cacheTable(tableName) + @since(1.0) def uncacheTable(self, tableName): """Removes the specified table from the in-memory cache.""" self._ssql_ctx.uncacheTable(tableName) + @since(1.3) def clearCache(self): """Removes all cached tables from the in-memory cache. """ self._ssql_ctx.clearCache() + @property + @since(1.4) + def read(self): + """ + Returns a :class:`DataFrameReader` that can be used to read data + in as a :class:`DataFrame`. + + :return: :class:`DataFrameReader` + """ + return DataFrameReader(self) + class HiveContext(SQLContext): """A variant of Spark SQL that integrates with data stored in Hive. @@ -600,10 +670,14 @@ def register(self, name, f, returnType=StringType()): def _test(): + import os import doctest from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2ed95ac8e250..1e9c657cf81b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -22,20 +22,21 @@ if sys.version >= '3': basestring = unicode = str long = int + from functools import reduce else: from itertools import imap as map -from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from pyspark.sql.types import * +from pyspark.sql import since from pyspark.sql.types import _create_cls, _parse_datatype_json_string +from pyspark.sql.column import Column, _to_seq, _to_java_column +from pyspark.sql.readwriter import DataFrameWriter +from pyspark.sql.types import * - -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions", - "DataFrameStatFunctions"] +__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] class DataFrame(object): @@ -44,7 +45,7 @@ class DataFrame(object): A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: - people = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. @@ -56,11 +57,15 @@ class DataFrame(object): A more concrete example:: # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") + department = sqlContext.read.parquet("...") people.filter(people.age > 30).join(department, people.deptId == department.id)) \ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + + .. note:: Experimental + + .. versionadded:: 1.3 """ def __init__(self, jdf, sql_ctx): @@ -72,6 +77,7 @@ def __init__(self, jdf, sql_ctx): self._lazy_rdd = None @property + @since(1.3) def rdd(self): """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ @@ -89,18 +95,21 @@ def applySchema(it): return self._lazy_rdd @property + @since("1.3.1") def na(self): """Returns a :class:`DataFrameNaFunctions` for handling missing values. """ return DataFrameNaFunctions(self) @property + @since(1.4) def stat(self): """Returns a :class:`DataFrameStatFunctions` for statistic functions. """ return DataFrameStatFunctions(self) @ignore_unicode_prefix + @since(1.3) def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. @@ -115,19 +124,12 @@ def toJSON(self, use_unicode=True): def saveAsParquetFile(self, path): """Saves the contents as a Parquet file, preserving the schema. - Files that are written out using this method can be read back in as - a :class:`DataFrame` using :func:`SQLContext.parquetFile`. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. """ + warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") self._jdf.saveAsParquetFile(path) + @since(1.3) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. @@ -142,81 +144,49 @@ def registerTempTable(self, name): self._jdf.registerTempTable(name) def registerAsTable(self, name): - """DEPRECATED: use :func:`registerTempTable` instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + """ + .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. + """ + warnings.warn("Use registerTempTable instead of registerAsTable.") self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this :class:`DataFrame` into the specified table. - Optionally overwriting any existing data. - """ - self._jdf.insertInto(tableName, overwrite) - - def _java_save_mode(self, mode): - """Returns the Java save mode based on the Python save mode represented by a string. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. """ - jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode - jmode = jSaveMode.ErrorIfExists - mode = mode.lower() - if mode == "append": - jmode = jSaveMode.Append - elif mode == "overwrite": - jmode = jSaveMode.Overwrite - elif mode == "ignore": - jmode = jSaveMode.Ignore - elif mode == "error": - pass - else: - raise ValueError( - "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") - return jmode + warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") + self.write.insertInto(tableName, overwrite) def saveAsTable(self, tableName, source=None, mode="error", **options): """Saves the contents of this :class:`DataFrame` to a data source as a table. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the saveAsTable operation when - table already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. """ - if source is None: - source = self.sql_ctx.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - jmode = self._java_save_mode(mode) - self._jdf.saveAsTable(tableName, source, jmode, options) + warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") + self.write.saveAsTable(tableName, source, mode, **options) + @since(1.3) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. + """ + warnings.warn("insertInto is deprecated. Use write.save() instead.") + return self.write.save(path, source, mode, **options) - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + @property + @since(1.4) + def write(self): + """ + Interface for saving the content of the :class:`DataFrame` out into external storage. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + :return: :class:`DataFrameWriter` """ - if path is not None: - options["path"] = path - if source is None: - source = self.sql_ctx.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - jmode = self._java_save_mode(mode) - self._jdf.save(source, jmode, options) + return DataFrameWriter(self) @property + @since(1.3) def schema(self): """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. @@ -224,9 +194,14 @@ def schema(self): StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) """ if self._schema is None: - self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + try: + self._schema = _parse_datatype_json_string(self._jdf.schema().json()) + except AttributeError as e: + raise Exception( + "Unable to parse datatype from schema. %s" % e) return self._schema + @since(1.3) def printSchema(self): """Prints out the schema in the tree format. @@ -238,6 +213,7 @@ def printSchema(self): """ print(self._jdf.schema().treeString()) + @since(1.3) def explain(self, extended=False): """Prints the (logical and physical) plans to the console for debugging purpose. @@ -263,15 +239,20 @@ def explain(self, extended=False): else: print(self._jdf.queryExecution().executedPlan().toString()) + @since(1.3) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally (without any Spark executors). """ return self._jdf.isLocal() - def show(self, n=20): + @since(1.3) + def show(self, n=20, truncate=True): """Prints the first ``n`` rows to the console. + :param n: Number of rows to show. + :param truncate: Whether truncate long strings and align cells right. + >>> df DataFrame[age: int, name: string] >>> df.show() @@ -282,11 +263,12 @@ def show(self, n=20): | 5| Bob| +---+-----+ """ - print(self._jdf.showString(n)) + print(self._jdf.showString(n, truncate)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. @@ -296,6 +278,7 @@ def count(self): return int(self._jdf.count()) @ignore_unicode_prefix + @since(1.3) def collect(self): """Returns all the records as a list of :class:`Row`. @@ -309,6 +292,7 @@ def collect(self): return [cls(r) for r in rs] @ignore_unicode_prefix + @since(1.3) def limit(self, num): """Limits the result count to the number specified. @@ -321,6 +305,7 @@ def limit(self, num): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. @@ -330,6 +315,7 @@ def take(self, num): return self.limit(num).collect() @ignore_unicode_prefix + @since(1.3) def map(self, f): """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. @@ -341,6 +327,7 @@ def map(self, f): return self.rdd.map(f) @ignore_unicode_prefix + @since(1.3) def flatMap(self, f): """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, and then flattening the results. @@ -352,6 +339,7 @@ def flatMap(self, f): """ return self.rdd.flatMap(f) + @since(1.3) def mapPartitions(self, f, preservesPartitioning=False): """Returns a new :class:`RDD` by applying the ``f`` function to each partition. @@ -364,6 +352,7 @@ def mapPartitions(self, f, preservesPartitioning=False): """ return self.rdd.mapPartitions(f, preservesPartitioning) + @since(1.3) def foreach(self, f): """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. @@ -375,6 +364,7 @@ def foreach(self, f): """ return self.rdd.foreach(f) + @since(1.3) def foreachPartition(self, f): """Applies the ``f`` function to each partition of this :class:`DataFrame`. @@ -387,6 +377,7 @@ def foreachPartition(self, f): """ return self.rdd.foreachPartition(f) + @since(1.3) def cache(self): """ Persists with the default storage level (C{MEMORY_ONLY_SER}). """ @@ -394,6 +385,7 @@ def cache(self): self._jdf.cache() return self + @since(1.3) def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign @@ -405,6 +397,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): self._jdf.persist(javaStorageLevel) return self + @since(1.3) def unpersist(self, blocking=True): """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from memory and disk. @@ -413,10 +406,22 @@ def unpersist(self, blocking=True): self._jdf.unpersist(blocking) return self - # def coalesce(self, numPartitions, shuffle=False): - # rdd = self._jdf.coalesce(numPartitions, shuffle, None) - # return DataFrame(rdd, self.sql_ctx) + @since(1.4) + def coalesce(self, numPartitions): + """ + Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + + Similar to coalesce defined on an :class:`RDD`, this operation results in a + narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, + there will not be a shuffle, instead each of the 100 new partitions will + claim 10 of the current partitions. + + >>> df.coalesce(1).rdd.getNumPartitions() + 1 + """ + return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx) + @since(1.3) def repartition(self, numPartitions): """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. @@ -425,6 +430,7 @@ def repartition(self, numPartitions): """ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -433,6 +439,7 @@ def distinct(self): """ return DataFrame(self._jdf.distinct(), self.sql_ctx) + @since(1.3) def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. @@ -444,6 +451,7 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. @@ -466,6 +474,7 @@ def randomSplit(self, weights, seed=None): return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property + @since(1.3) def dtypes(self): """Returns all column names and their data types as a list. @@ -475,16 +484,17 @@ def dtypes(self): return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property - @ignore_unicode_prefix + @since(1.3) def columns(self): """Returns all column names as a list. >>> df.columns - [u'age', u'name'] + ['age', 'name'] """ return [f.name for f in self.schema.fields] @ignore_unicode_prefix + @since(1.3) def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. @@ -499,39 +509,57 @@ def alias(self, alias): return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @ignore_unicode_prefix - def join(self, other, joinExprs=None, joinType=None): + @since(1.3) + def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join - :param joinExprs: a string for join column name, or a join expression (Column). - If joinExprs is a string indicating the name of the join column, - the column must exist on both sides, and this performs an inner equi-join. - :param joinType: str, default 'inner'. + :param on: a string for join column name, a list of column names, + , a join expression (Column) or a list of Columns. + If `on` is a string or a list of string indicating the name of the join column(s), + the column(s) must exist on both sides, and this performs an inner equi-join. + :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + >>> cond = [df.name == df3.name, df.age == df3.age] + >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() + [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)] + >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] + + >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() + [Row(name=u'Bob', age=5)] """ - if joinExprs is None: + if on is not None and not isinstance(on, list): + on = [on] + + if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - elif isinstance(joinExprs, basestring): - jdf = self._jdf.join(other._jdf, joinExprs) + + if isinstance(on[0], basestring): + jdf = self._jdf.join(other._jdf, self._jseq(on)) else: - assert isinstance(joinExprs, Column), "joinExprs should be Column" - if joinType is None: - jdf = self._jdf.join(other._jdf, joinExprs._jc) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) else: - assert isinstance(joinType, basestring), "joinType should be basestring" - jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) + on = on[0] + if how is None: + jdf = self._jdf.join(other._jdf, on._jc, "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). @@ -591,12 +619,16 @@ def _jcols(self, *cols): cols = cols[0] return self._jseq(cols, _to_java_column) + @since("1.3.1") def describe(self, *cols): """Computes statistics for numeric columns. This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical columns. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + >>> df.describe().show() +-------+---+ |summary|age| @@ -607,15 +639,30 @@ def describe(self, *cols): | min| 2| | max| 5| +-------+---+ + >>> df.describe(['age', 'name']).show() + +-------+---+-----+ + |summary|age| name| + +-------+---+-----+ + | count| 2| 2| + | mean|3.5| null| + | stddev|1.5| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+---+-----+ """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def head(self, n=None): - """ - Returns the first ``n`` rows as a list of :class:`Row`, - or the first :class:`Row` if ``n`` is ``None.`` + """Returns the first ``n`` rows. + + :param n: int, default 1. Number of rows to return. + :return: If n is greater than 1, return a list of :class:`Row`. + If n is 1, return a single Row. >>> df.head() Row(age=2, name=u'Alice') @@ -628,6 +675,7 @@ def head(self, n=None): return self.take(n) @ignore_unicode_prefix + @since(1.3) def first(self): """Returns the first row as a :class:`Row`. @@ -637,6 +685,7 @@ def first(self): return self.head() @ignore_unicode_prefix + @since(1.3) def __getitem__(self, item): """Returns the column as a :class:`Column`. @@ -664,6 +713,7 @@ def __getitem__(self, item): else: raise TypeError("unexpected item type: %s" % type(item)) + @since(1.3) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. @@ -677,6 +727,7 @@ def __getattr__(self, name): return Column(jc) @ignore_unicode_prefix + @since(1.3) def select(self, *cols): """Projects a set of expressions and returns a new :class:`DataFrame`. @@ -694,13 +745,14 @@ def select(self, *cols): jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx) + @since(1.3) def selectExpr(self, *expr): """Projects a set of SQL expressions and returns a new :class:`DataFrame`. This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] + [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] @@ -708,6 +760,7 @@ def selectExpr(self, *expr): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def filter(self, condition): """Filters rows using the given condition. @@ -737,6 +790,7 @@ def filter(self, condition): where = filter @ignore_unicode_prefix + @since(1.3) def groupBy(self, *cols): """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` @@ -748,29 +802,76 @@ def groupBy(self, *cols): Each element should be a column name (string) or an expression (:class:`Column`). >>> df.groupBy().avg().collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ - jdf = self._jdf.groupBy(self._jcols(*cols)) - return GroupedData(jdf, self.sql_ctx) - + jgd = self._jdf.groupBy(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.4) + def rollup(self, *cols): + """ + Create a multi-dimensional rollup for the current :class:`DataFrame` using + the specified columns, so we can run aggregation on them. + + >>> df.rollup('name', df.age).count().show() + +-----+----+-----+ + | name| age|count| + +-----+----+-----+ + |Alice|null| 1| + | Bob| 5| 1| + | Bob|null| 1| + | null|null| 2| + |Alice| 2| 1| + +-----+----+-----+ + """ + jgd = self._jdf.rollup(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.4) + def cube(self, *cols): + """ + Create a multi-dimensional cube for the current :class:`DataFrame` using + the specified columns, so we can run aggregation on them. + + >>> df.cube('name', df.age).count().show() + +-----+----+-----+ + | name| age|count| + +-----+----+-----+ + | null| 2| 1| + |Alice|null| 1| + | Bob| 5| 1| + | Bob|null| 1| + | null| 5| 1| + | null|null| 2| + |Alice| 2| 1| + +-----+----+-----+ + """ + jgd = self._jdf.cube(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.3) def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() - [Row(MAX(age)=5)] + [Row(max(age)=5)] >>> from pyspark.sql import functions as F >>> df.agg(F.min(df.age)).collect() - [Row(MIN(age)=2)] + [Row(min(age)=2)] """ return self.groupBy().agg(*exprs) + @since(1.3) def unionAll(self, other): """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. @@ -779,6 +880,7 @@ def unionAll(self, other): """ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + @since(1.3) def intersect(self, other): """ Return a new :class:`DataFrame` containing rows only in both this frame and another frame. @@ -787,6 +889,7 @@ def intersect(self, other): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + @since(1.3) def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. @@ -795,6 +898,7 @@ def subtract(self, other): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + @since(1.4) def dropDuplicates(self, subset=None): """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -825,10 +929,10 @@ def dropDuplicates(self, subset=None): jdf = self._jdf.dropDuplicates(self._jseq(subset)) return DataFrame(jdf, self.sql_ctx) + @since("1.3.1") def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. - - This is an alias for ``na.drop()``. + :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other. :param how: 'any' or 'all'. If 'any', drop a row if it contains any nulls. @@ -838,13 +942,6 @@ def dropna(self, how='any', thresh=None, subset=None): This overwrites the `how` parameter. :param subset: optional list of column names to consider. - >>> df4.dropna().show() - +---+------+-----+ - |age|height| name| - +---+------+-----+ - | 10| 80|Alice| - +---+------+-----+ - >>> df4.na.drop().show() +---+------+-----+ |age|height| name| @@ -867,8 +964,10 @@ def dropna(self, how='any', thresh=None, subset=None): return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) + @since("1.3.1") def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. + :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. :param value: int, long, float, string, or dict. Value to replace null values with. @@ -880,7 +979,7 @@ def fillna(self, value, subset=None): For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. - >>> df4.fillna(50).show() + >>> df4.na.fill(50).show() +---+------+-----+ |age|height| name| +---+------+-----+ @@ -890,16 +989,6 @@ def fillna(self, value, subset=None): | 50| 50| null| +---+------+-----+ - >>> df4.fillna({'age': 50, 'name': 'unknown'}).show() - +---+------+-------+ - |age|height| name| - +---+------+-------+ - | 10| 80| Alice| - | 5| null| Bob| - | 50| null| Tom| - | 50| null|unknown| - +---+------+-------+ - >>> df4.na.fill({'age': 50, 'name': 'unknown'}).show() +---+------+-------+ |age|height| name| @@ -928,8 +1017,11 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + @since(1.4) def replace(self, to_replace, value, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. + :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are + aliases of each other. :param to_replace: int, long, float, string, or list. Value to be replaced. @@ -944,7 +1036,8 @@ def replace(self, to_replace, value, subset=None): Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. - >>> df4.replace(10, 20).show() + + >>> df4.na.replace(10, 20).show() +----+------+-----+ | age|height| name| +----+------+-----+ @@ -954,7 +1047,7 @@ def replace(self, to_replace, value, subset=None): |null| null| null| +----+------+-----+ - >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| +----+------+----+ @@ -1002,11 +1095,12 @@ def replace(self, to_replace, value, subset=None): return DataFrame( self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + @since(1.4) def corr(self, col1, col2, method=None): """ - Calculates the correlation of two columns of a DataFrame as a double value. Currently only - supports the Pearson Correlation Coefficient. - :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases. + Calculates the correlation of two columns of a DataFrame as a double value. + Currently only supports the Pearson Correlation Coefficient. + :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other. :param col1: The name of the first column :param col2: The name of the second column @@ -1023,6 +1117,7 @@ def corr(self, col1, col2, method=None): "coefficient is supported.") return self._jdf.stat().corr(col1, col2, method) + @since(1.4) def cov(self, col1, col2): """ Calculate the sample covariance for the given columns, specified by their names, as a @@ -1037,6 +1132,7 @@ def cov(self, col1, col2): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) + @since(1.4) def crosstab(self, col1, col2): """ Computes a pair-wise frequency table of the given columns. Also known as a contingency @@ -1058,6 +1154,7 @@ def crosstab(self, col1, col2): raise ValueError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) + @since(1.4) def freqItems(self, cols, support=None): """ Finding frequent items for columns, possibly with false positives. Using the @@ -1065,6 +1162,9 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. :param support: The frequency with which to consider an item 'frequent'. Default is 1%. @@ -1079,6 +1179,7 @@ def freqItems(self, cols, support=None): return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -1091,6 +1192,7 @@ def withColumn(self, colName, col): return self.select('*', col.alias(colName)) @ignore_unicode_prefix + @since(1.3) def withColumnRenamed(self, existing, new): """Returns a new :class:`DataFrame` by renaming an existing column. @@ -1105,18 +1207,35 @@ def withColumnRenamed(self, existing, new): for c in self.columns] return self.select(*cols) + @since(1.4) @ignore_unicode_prefix - def drop(self, colName): + def drop(self, col): """Returns a new :class:`DataFrame` that drops the specified column. - :param colName: string, name of the column to drop. + :param col: a string name of the column to drop, or a + :class:`Column` to drop. >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.drop(df.age).collect() + [Row(name=u'Alice'), Row(name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect() + [Row(age=5, height=85, name=u'Bob')] + + >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect() + [Row(age=5, name=u'Bob', height=85)] """ - jdf = self._jdf.drop(colName) + if isinstance(col, basestring): + jdf = self._jdf.drop(col) + elif isinstance(col, Column): + jdf = self._jdf.drop(col._jc) + else: + raise TypeError("col should be a string or a Column") return DataFrame(jdf, self.sql_ctx) + @since(1.3) def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. @@ -1130,7 +1249,10 @@ def toPandas(self): import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) + ########################################################################################## # Pandas compatibility + ########################################################################################## + groupby = groupBy drop_duplicates = dropDuplicates @@ -1141,169 +1263,6 @@ class SchemaRDD(DataFrame): """ -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -def df_varargs_api(f): - def _api(self, *args): - name = f.__name__ - jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedData(object): - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by :func:`DataFrame.groupBy`. - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - @ignore_unicode_prefix - def agg(self, *exprs): - """Compute aggregates and returns the result as a :class:`DataFrame`. - - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. - - If ``exprs`` is a single :class:`dict` mapping from string to string, then the key - is the column to perform aggregation on, and the value is the aggregate function. - - Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. - - :param exprs: a dict mapping from column name (string) to aggregate functions (string), - or a list of :class:`Column`. - - >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] - - >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() - [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jdf = self._jdf.agg(exprs[0]) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jdf.agg(exprs[0]._jc, - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """Counts the number of records for each group. - - >>> df.groupBy(df.age).count().collect() - [Row(age=2, count=1), Row(age=5, count=1)] - """ - - @df_varargs_api - def mean(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().mean('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().mean('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def avg(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().avg('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().avg('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def max(self, *cols): - """Computes the max value for each numeric columns for each group. - - >>> df.groupBy().max('age').collect() - [Row(MAX(age)=5)] - >>> df3.groupBy().max('age', 'height').collect() - [Row(MAX(age)=5, MAX(height)=85)] - """ - - @df_varargs_api - def min(self, *cols): - """Computes the min value for each numeric column for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().min('age').collect() - [Row(MIN(age)=2)] - >>> df3.groupBy().min('age', 'height').collect() - [Row(MIN(age)=2, MIN(height)=80)] - """ - - @df_varargs_api - def sum(self, *cols): - """Compute the sum for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().sum('age').collect() - [Row(SUM(age)=7)] - >>> df3.groupBy().sum('age', 'height').collect() - [Row(SUM(age)=7, SUM(height)=165)] - """ - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.functions.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.functions.col(name) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _to_seq(sc, cols, converter=None): - """ - Convert a list of Column (or names) into a JVM Seq of Column. - - An optional `converter` could be used to convert items in `cols` - into JVM Column objects. - """ - if converter: - cols = [converter(c) for c in cols] - return sc._jvm.PythonUtils.toSeq(cols) - - def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. @@ -1311,284 +1270,10 @@ def _to_scala_map(sc, jm): return sc._jvm.PythonUtils.toScalaMap(jm) -def _unary_op(name, doc="unary operator"): - """ Create a method for given unary operator """ - def _(self): - jc = getattr(self._jc, name)() - return Column(jc) - _.__doc__ = doc - return _ - - -def _func_op(name, doc=''): - def _(self): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.functions, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -def _bin_op(name, doc="binary operator"): - """ Create a method for given binary operator - """ - def _(self, other): - jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, name)(jc) - return Column(njc) - _.__doc__ = doc - return _ - - -def _reverse_op(name, doc="binary operator"): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - jother = _create_column_from_literal(other) - jc = getattr(jother, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -class Column(object): - - """ - A column in a DataFrame. - - :class:`Column` instances can be created by:: - - # 1. Select a column out of a DataFrame - - df.colName - df["colName"] - - # 2. Create from an expression - df.colName + 1 - 1 / df.colName - """ - - def __init__(self, jc): - self._jc = jc - - # arithmetic operators - __neg__ = _func_op("negate") - __add__ = _bin_op("plus") - __sub__ = _bin_op("minus") - __mul__ = _bin_op("multiply") - __div__ = _bin_op("divide") - __truediv__ = _bin_op("divide") - __mod__ = _bin_op("mod") - __radd__ = _bin_op("plus") - __rsub__ = _reverse_op("minus") - __rmul__ = _bin_op("multiply") - __rdiv__ = _reverse_op("divide") - __rtruediv__ = _reverse_op("divide") - __rmod__ = _reverse_op("mod") - - # logistic operators - __eq__ = _bin_op("equalTo") - __ne__ = _bin_op("notEqual") - __lt__ = _bin_op("lt") - __le__ = _bin_op("leq") - __ge__ = _bin_op("geq") - __gt__ = _bin_op("gt") - - # `and`, `or`, `not` cannot be overloaded in Python, - # so use bitwise operators as boolean operators - __and__ = _bin_op('and') - __or__ = _bin_op('or') - __invert__ = _func_op('not') - __rand__ = _bin_op("and") - __ror__ = _bin_op("or") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("apply") - - # bitwise operators - bitwiseOR = _bin_op("bitwiseOR") - bitwiseAND = _bin_op("bitwiseAND") - bitwiseXOR = _bin_op("bitwiseXOR") - - def getItem(self, key): - """An expression that gets an item at position `ordinal` out of a list, - or gets an item by key out of a dict. - - >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) - >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - >>> df.select(df.l[0], df.d["key"]).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - """ - return self[key] - - def getField(self, name): - """An expression that gets a field by name in a StructField. - - >>> from pyspark.sql import Row - >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() - >>> df.select(df.r.getField("b")).show() - +----+ - |r[b]| - +----+ - | b| - +----+ - >>> df.select(df.r.a).show() - +----+ - |r[a]| - +----+ - | 1| - +----+ - """ - return self[name] - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - return self.getField(item) - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - - @ignore_unicode_prefix - def substr(self, startPos, length): - """ - Return a :class:`Column` which is a substring of the column - - :param startPos: start position (int or Column) - :param length: length of the substring (int or Column) - - >>> df.select(df.name.substr(1, 3).alias("col")).collect() - [Row(col=u'Ali'), Row(col=u'Bob')] - """ - if type(startPos) != type(length): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, length) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, length._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc) - - __getslice__ = substr - - @ignore_unicode_prefix - def inSet(self, *cols): - """ A boolean expression that is evaluated to true if the value of this - expression is contained by the evaluated values of the arguments. - - >>> df[df.name.inSet("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] - >>> df[df.age.inSet([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] - """ - if len(cols) == 1 and isinstance(cols[0], (list, set)): - cols = cols[0] - cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] - sc = SparkContext._active_spark_context - jc = getattr(self._jc, "in")(_to_seq(sc, cols)) - return Column(jc) - - # order - asc = _unary_op("asc", "Returns a sort expression based on the" - " ascending order of the given column name.") - desc = _unary_op("desc", "Returns a sort expression based on the" - " descending order of the given column name.") - - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - - def alias(self, *alias): - """Returns this column aliased with a new name or names (in the case of expressions that - return more than one column, such as explode). - - >>> df.select(df.age.alias("age2")).collect() - [Row(age2=2), Row(age2=5)] - """ - - if len(alias) == 1: - return Column(getattr(self._jc, "as")(alias[0])) - else: - sc = SparkContext._active_spark_context - return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) - - @ignore_unicode_prefix - def cast(self, dataType): - """ Convert the column into type `dataType` - - >>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - """ - if isinstance(dataType, basestring): - jc = self._jc.cast(dataType) - elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) - jc = self._jc.cast(jdt) - else: - raise TypeError("unexpected type: %s" % type(dataType)) - return Column(jc) - - @ignore_unicode_prefix - def between(self, lowerBound, upperBound): - """ A boolean expression that is evaluated to true if the value of this - expression is between the given columns. - """ - return (self >= lowerBound) & (self <= upperBound) - - @ignore_unicode_prefix - def when(self, condition, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param condition: a boolean :class:`Column` expression. - :param value: a literal value, or a :class:`Column` expression. - - """ - sc = SparkContext._active_spark_context - if not isinstance(condition, Column): - raise TypeError("condition should be a Column") - v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) - return Column(jc) - - @ignore_unicode_prefix - def otherwise(self, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param value: a literal value, or a :class:`Column` expression. - """ - v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) - return Column(jc) - - def __repr__(self): - return 'Column<%s>' % self._jc.toString().encode('utf8') - - class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. + + .. versionadded:: 1.4 """ def __init__(self, df): @@ -1604,9 +1289,16 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ + def replace(self, to_replace, value, subset=None): + return self.df.replace(to_replace, value, subset) + + replace.__doc__ = DataFrame.replace.__doc__ + class DataFrameStatFunctions(object): """Functionality for statistic functions with :class:`DataFrame`. + + .. versionadded:: 1.4 """ def __init__(self, df): @@ -1646,9 +1338,8 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() - globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), - Row(name='Bob', age=5, height=85)]).toDF() - + globs['df3'] = sc.parallelize([Row(name='Alice', age=2), + Row(name='Bob', age=5)]).toDF() globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6cd6974b0e5b..dca39fa83343 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,6 +18,7 @@ """ A collections of builtin functions """ +import math import sys if sys.version < "3": @@ -26,21 +27,33 @@ from pyspark import SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql import since from pyspark.sql.types import StringType -from pyspark.sql.dataframe import Column, _to_java_column, _to_seq +from pyspark.sql.column import Column, _to_java_column, _to_seq __all__ = [ + 'array', 'approxCountDistinct', + 'bin', 'coalesce', 'countDistinct', + 'explode', + 'log2', + 'md5', 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha1', + 'sha2', 'sparkPartitionId', + 'strlen', + 'struct', 'udf', 'when'] +__all__ += ['lag', 'lead', 'ntile'] + def _create_function(name, doc=""): """ Create a function for aggregator by name""" @@ -66,6 +79,17 @@ def _(col1, col2): return _ +def _create_window_function(name, doc=''): + """ Create a window function by name """ + def _(): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)() + return Column(jc) + _.__name__ = name + _.__doc__ = 'Window function: ' + doc + return _ + + _functions = { 'lit': 'Creates a :class:`Column` of literal value.', 'col': 'Returns a :class:`Column` based on the given column name.', @@ -78,6 +102,18 @@ def _(col1, col2): 'sqrt': 'Computes the square root of the specified float value.', 'abs': 'Computes the absolute value.', + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', +} + +_functions_1_4 = { # unary math functions 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + '0.0 through pi.', @@ -102,21 +138,11 @@ def _(col1, col2): 'tan': 'Computes the tangent of the given value.', 'tanh': 'Computes the hyperbolic tangent of the given value.', 'toDegrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', + 'measured in degrees.', 'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', + 'measured in radians.', 'bitwiseNOT': 'Computes bitwise not.', - - 'max': 'Aggregate function: returns the maximum value of the expression in a group.', - 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', - 'count': 'Aggregate function: returns the number of items in a group.', - 'sum': 'Aggregate function: returns the sum of all values in the expression.', - 'avg': 'Aggregate function: returns the average of the values in a group.', - 'mean': 'Aggregate function: returns the average of the values in a group.', - 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', } # math functions that take two arguments as input @@ -124,19 +150,60 @@ def _(col1, col2): 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + 'polar coordinates (r, theta).', 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', - 'pow': 'Returns the value of the first argument raised to the power of the second argument.' + 'pow': 'Returns the value of the first argument raised to the power of the second argument.', +} + +_window_functions = { + 'rowNumber': + """returns a sequential number starting at 1 within a window partition. + + This is equivalent to the ROW_NUMBER function in SQL.""", + 'denseRank': + """returns the rank of rows within a window partition, without any gaps. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the DENSE_RANK function in SQL.""", + 'rank': + """returns the rank of rows within a window partition. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the RANK function in SQL.""", + 'cumeDist': + """returns the cumulative distribution of values within a window partition, + i.e. the fraction of rows that are below the current row. + + This is equivalent to the CUME_DIST function in SQL.""", + 'percentRank': + """returns the relative rank (i.e. percentile) of rows within a window partition. + + This is equivalent to the PERCENT_RANK function in SQL.""", } for _name, _doc in _functions.items(): - globals()[_name] = _create_function(_name, _doc) + globals()[_name] = since(1.3)(_create_function(_name, _doc)) +for _name, _doc in _functions_1_4.items(): + globals()[_name] = since(1.4)(_create_function(_name, _doc)) for _name, _doc in _binary_mathfunctions.items(): - globals()[_name] = _create_binary_mathfunction(_name, _doc) + globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) +for _name, _doc in _window_functions.items(): + globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) del _name, _doc __all__ += _functions.keys() +__all__ += _functions_1_4.keys() __all__ += _binary_mathfunctions.keys() +__all__ += _window_functions.keys() __all__.sort() +@since(1.4) def array(*cols): """Creates a new array column. @@ -155,6 +222,7 @@ def array(*cols): return Column(jc) +@since(1.3) def approxCountDistinct(col, rsd=None): """Returns a new :class:`Column` for approximate distinct count of ``col``. @@ -169,26 +237,20 @@ def approxCountDistinct(col, rsd=None): return Column(jc) -def explode(col): - """Returns a new row for each element in the given array or map. - - >>> from pyspark.sql import Row - >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) - >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() - [Row(anInt=1), Row(anInt=2), Row(anInt=3)] +@ignore_unicode_prefix +@since(1.5) +def bin(col): + """Returns the string representation of the binary value of the given column. - >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() - +---+-----+ - |key|value| - +---+-----+ - | a| b| - +---+-----+ + >>> df.select(bin(df.age).alias('c')).collect() + [Row(c=u'10'), Row(c=u'101')] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.explode(_to_java_column(col)) + jc = sc._jvm.functions.bin(_to_java_column(col)) return Column(jc) +@since(1.4) def coalesce(*cols): """Returns the first column that is not null. @@ -204,7 +266,7 @@ def coalesce(*cols): >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() +-------------+ - |Coalesce(a,b)| + |coalesce(a,b)| +-------------+ | null| | 1| @@ -213,7 +275,7 @@ def coalesce(*cols): >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() +----+----+---------------+ - | a| b|Coalesce(a,0.0)| + | a| b|coalesce(a,0.0)| +----+----+---------------+ |null|null| 0.0| | 1|null| 1.0| @@ -225,6 +287,7 @@ def coalesce(*cols): return Column(jc) +@since(1.3) def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. @@ -239,6 +302,55 @@ def countDistinct(col, *cols): return Column(jc) +@since(1.4) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def levenshtein(left, right): + """Computes the Levenshtein distance of the two given strings. + + >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) + >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + [Row(d=3)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def md5(col): + """Calculates the MD5 digest and returns the value as a 32 character hex string. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.md5(_to_java_column(col)) + return Column(jc) + + +@since(1.4) def monotonicallyIncreasingId(): """A column that generates monotonically increasing 64-bit integers. @@ -247,7 +359,7 @@ def monotonicallyIncreasingId(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. - As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -259,6 +371,7 @@ def monotonicallyIncreasingId(): return Column(sc._jvm.functions.monotonicallyIncreasingId()) +@since(1.4) def rand(seed=None): """Generates a random column with i.i.d. samples from U[0.0, 1.0]. """ @@ -270,6 +383,7 @@ def rand(seed=None): return Column(jc) +@since(1.4) def randn(seed=None): """Generates a column with i.i.d. samples from the standard normal distribution. """ @@ -281,6 +395,103 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. + + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=bytearray(b'ABC'))] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.unhex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def sha2(col, numBits): + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() + >>> digests[0] + Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + >>> digests[1] + Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.5) +def shiftLeft(col, numBits): + """Shift the the given value numBits left. + + >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() + [Row(r=42)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.5) +def shiftRight(col, numBits): + """Shift the the given value numBits right. + + >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() + [Row(r=21)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.5) +def shiftRightUnsigned(col, numBits): + """Unsigned shift the the given value numBits right. + + >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\ + .collect() + [Row(r=9223372036854775787)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits) + return Column(jc) + + +@since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -294,11 +505,23 @@ def sparkPartitionId(): @ignore_unicode_prefix +@since(1.5) +def strlen(col): + """Calculates the length of a string expression. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + [Row(length=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.strlen(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.4) def struct(*cols): """Creates a new struct column. :param cols: list of column names (string) or list of :class:`Column` expressions - that are named or aliased. >>> df.select(struct('age', 'name').alias("struct")).collect() [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] @@ -312,6 +535,7 @@ def struct(*cols): return Column(jc) +@since(1.4) def when(condition, value): """Evaluates a list of conditions and returns one of multiple possible result expressions. If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. @@ -333,9 +557,91 @@ def when(condition, value): return Column(jc) +@since(1.5) +def log(arg1, arg2=None): + """Returns the first argument-based logarithm of the second argument. + + If there is only one argument, then this takes the natural logarithm of the argument. + + >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() + ['0.30102', '0.69897'] + + >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect() + ['0.69314', '1.60943'] + """ + sc = SparkContext._active_spark_context + if arg2 is None: + jc = sc._jvm.functions.log(_to_java_column(arg1)) + else: + jc = sc._jvm.functions.log(arg1, _to_java_column(arg2)) + return Column(jc) + + +@since(1.5) +def log2(col): + """Returns the base-2 logarithm of the argument. + + >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() + [Row(log2=2.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.log2(_to_java_column(col))) + + +@since(1.4) +def lag(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows before the current row, and + `defaultValue` if there is less than `offset` rows before the current row. For example, + an `offset` of one will return the previous row at any given point in the window partition. + + This is equivalent to the LAG function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lag(_to_java_column(col), count, default)) + + +@since(1.4) +def lead(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows after the current row, and + `defaultValue` if there is less than `offset` rows after the current row. For example, + an `offset` of one will return the next row at any given point in the window partition. + + This is equivalent to the LEAD function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lead(_to_java_column(col), count, default)) + + +@since(1.4) +def ntile(n): + """ + Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in + a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will + get 2, the third row will get 3, and the fourth row will get 1... + + This is equivalent to the NTILE function in SQL. + + :param n: an integer + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.ntile(int(n))) + + class UserDefinedFunction(object): """ User defined function in Python + + .. versionadded:: 1.3 """ def __init__(self, func, returnType): self.func = func @@ -353,8 +659,8 @@ def _create_judf(self): ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ - judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, - includes, sc.pythonExec, broadcast_vars, + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, + sc.pythonExec, sc.pythonVer, broadcast_vars, sc._javaAccumulator, jdt) return judf @@ -369,6 +675,7 @@ def __call__(self, *cols): return Column(jc) +@since(1.3) def udf(f, returnType=StringType()): """Creates a :class:`Column` expression representing a user defined function (UDF). diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py new file mode 100644 index 000000000000..04594d5a836c --- /dev/null +++ b/python/pyspark/sql/group.py @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql import since +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import * + +__all__ = ["GroupedData"] + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + name = f.__name__ + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by :func:`DataFrame.groupBy`. + + .. note:: Experimental + + .. versionadded:: 1.3 + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + @ignore_unicode_prefix + @since(1.3) + def agg(self, *exprs): + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. + + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jdf = self._jdf.agg(exprs[0]) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + @since(1.3) + def count(self): + """Counts the number of records for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + @since(1.3) + def mean(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().mean('age').collect() + [Row(avg(age)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(avg(age)=3.5, avg(height)=82.5)] + """ + + @df_varargs_api + @since(1.3) + def avg(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().avg('age').collect() + [Row(avg(age)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(avg(age)=3.5, avg(height)=82.5)] + """ + + @df_varargs_api + @since(1.3) + def max(self, *cols): + """Computes the max value for each numeric columns for each group. + + >>> df.groupBy().max('age').collect() + [Row(max(age)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(max(age)=5, max(height)=85)] + """ + + @df_varargs_api + @since(1.3) + def min(self, *cols): + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().min('age').collect() + [Row(min(age)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(min(age)=2, min(height)=80)] + """ + + @df_varargs_api + @since(1.3) + def sum(self, *cols): + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().sum('age').collect() + [Row(sum(age)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(sum(age)=7, sum(height)=165)] + """ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.group + globs = pyspark.sql.group.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.group, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py new file mode 100644 index 000000000000..882a03090ec1 --- /dev/null +++ b/python/pyspark/sql/readwriter.py @@ -0,0 +1,434 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import JavaClass + +from pyspark.sql import since +from pyspark.sql.column import _to_seq +from pyspark.sql.types import * + +__all__ = ["DataFrameReader", "DataFrameWriter"] + + +class DataFrameReader(object): + """ + Interface used to load a :class:`DataFrame` from external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read` + to access this. + + ::Note: Experimental + + .. versionadded:: 1.4 + """ + + def __init__(self, sqlContext): + self._jreader = sqlContext._ssql_ctx.read() + self._sqlContext = sqlContext + + def _df(self, jdf): + from pyspark.sql.dataframe import DataFrame + return DataFrame(jdf, self._sqlContext) + + @since(1.4) + def format(self, source): + """Specifies the input data source format. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + + """ + self._jreader = self._jreader.format(source) + return self + + @since(1.4) + def schema(self, schema): + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + :param schema: a StructType object + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(1.5) + def option(self, key, value): + """Adds an input option for the underlying data source. + """ + self._jreader = self._jreader.option(key, value) + return self + + @since(1.4) + def options(self, **options): + """Adds input options for the underlying data source. + """ + for k in options: + self._jreader = self._jreader.option(k, options[k]) + return self + + @since(1.4) + def load(self, path=None, format=None, schema=None, **options): + """Loads data from a data source and returns it as a :class`DataFrame`. + + :param path: optional string for file-system backed data sources. + :param format: optional string for format of the data source. Default to 'parquet'. + :param schema: optional :class:`StructType` for the input schema. + :param options: all other string options + + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + return self._df(self._jreader.load(path)) + else: + return self._df(self._jreader.load()) + + @since(1.4) + def json(self, path, schema=None): + """ + Loads a JSON file (one object per line) and returns the result as + a :class`DataFrame`. + + If the ``schema`` parameter is not specified, this function goes + through the input once to determine the input schema. + + :param path: string, path to the JSON dataset. + :param schema: an optional :class:`StructType` for the input schema. + + >>> df = sqlContext.read.json('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + + """ + if schema is not None: + self.schema(schema) + return self._df(self._jreader.json(path)) + + @since(1.4) + def table(self, tableName): + """Returns the specified table as a :class:`DataFrame`. + + :param tableName: string, name of the table. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.registerTempTable('tmpTable') + >>> sqlContext.read.table('tmpTable').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + return self._df(self._jreader.table(tableName)) + + @since(1.4) + def parquet(self, *path): + """Loads a Parquet file, returning the result as a :class:`DataFrame`. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path))) + + @since(1.4) + def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, + predicates=None, properties={}): + """ + Construct a :class:`DataFrame` representing the database table accessible + via JDBC URL `url` named `table` and connection `properties`. + + The `column` parameter could be used to partition the table, then it will + be retrieved in parallel based on the parameters passed to this function. + + The `predicates` parameter gives a list expressions suitable for inclusion + in WHERE clauses; each one defines one partition of the :class:`DataFrame`. + + ::Note: Don't create too many partitions in parallel on a large cluster; + otherwise Spark might crash your external database systems. + + :param url: a JDBC URL + :param table: name of table + :param column: the column used to partition + :param lowerBound: the lower bound of partition column + :param upperBound: the upper bound of the partition column + :param numPartitions: the number of partitions + :param predicates: a list of expressions + :param properties: JDBC database connection arguments, a list of arbitrary string + tag/value. Normally at least a "user" and "password" property + should be included. + :return: a DataFrame + """ + jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + for k in properties: + jprop.setProperty(k, properties[k]) + if column is not None: + if numPartitions is None: + numPartitions = self._sqlContext._sc.defaultParallelism + return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), + int(numPartitions), jprop)) + if predicates is not None: + arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) + return self._df(self._jreader.jdbc(url, table, arr, jprop)) + return self._df(self._jreader.jdbc(url, table, jprop)) + + +class DataFrameWriter(object): + """ + Interface used to write a [[DataFrame]] to external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write` + to access this. + + ::Note: Experimental + + .. versionadded:: 1.4 + """ + def __init__(self, df): + self._df = df + self._sqlContext = df.sql_ctx + self._jwrite = df._jdf.write() + + @since(1.4) + def mode(self, saveMode): + """Specifies the behavior when data or table already exists. + + Options include: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + # At the JVM side, the default value of mode is already set to "error". + # So, if the given saveMode is None, we will not call JVM-side's mode method. + if saveMode is not None: + self._jwrite = self._jwrite.mode(saveMode) + return self + + @since(1.4) + def format(self, source): + """Specifies the underlying output data source. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(1.5) + def option(self, key, value): + """Adds an output option for the underlying data source. + """ + self._jwrite = self._jwrite.option(key, value) + return self + + @since(1.4) + def options(self, **options): + """Adds output options for the underlying data source. + """ + for k in options: + self._jwrite = self._jwrite.option(k, options[k]) + return self + + @since(1.4) + def partitionBy(self, *cols): + """Partitions the output by the given columns on the file system. + + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + :param cols: name of columns + + >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + return self + + @since(1.4) + def save(self, path=None, format=None, mode=None, partitionBy=None, **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``format`` and a set of ``options``. + If ``format`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + :param path: the path in a Hadoop supported file system + :param format: the format used to save + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + :param options: all other string options + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) + if format is not None: + self.format(format) + if path is None: + self._jwrite.save() + else: + self._jwrite.save(path) + + @since(1.4) + def insertInto(self, tableName, overwrite=False): + """Inserts the content of the :class:`DataFrame` to the specified table. + + It requires that the schema of the class:`DataFrame` is the same as the + schema of the table. + + Optionally overwriting any existing data. + """ + self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) + + @since(1.4) + def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options): + """Saves the content of the :class:`DataFrame` as the specified table. + + In the case the table already exists, behavior of this function depends on the + save mode, specified by the `mode` function (default to throwing an exception). + When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + the same as that of the existing table. + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + + :param name: the table name + :param format: the format used to save + :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param partitionBy: names of partitioning columns + :param options: all other string options + """ + self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) + if format is not None: + self.format(format) + self._jwrite.saveAsTable(name) + + @since(1.4) + def json(self, path, mode=None): + """Saves the content of the :class:`DataFrame` in JSON format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode)._jwrite.json(path) + + @since(1.4) + def parquet(self, path, mode=None, partitionBy=None): + """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + + >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode) + if partitionBy is not None: + self.partitionBy(partitionBy) + self._jwrite.parquet(path) + + @since(1.4) + def jdbc(self, url, table, mode=None, properties={}): + """Saves the content of the :class:`DataFrame` to a external database table via JDBC. + + .. note:: Don't create too many partitions in parallel on a large cluster;\ + otherwise Spark might crash your external database systems. + + :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` + :param table: Name of the table in the external database. + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param properties: JDBC database connection arguments, a list of + arbitrary string tag/value. Normally at least a + "user" and "password" property should be included. + """ + jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + for k in properties: + jprop.setProperty(k, properties[k]) + self._jwrite.mode(mode).jdbc(url, table, jprop) + + +def _test(): + import doctest + import os + import tempfile + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.readwriter + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.sql.readwriter.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + + globs['tempfile'] = tempfile + globs['os'] = os + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.readwriter, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d37c5dbed7f6..333378c7f185 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1,3 +1,4 @@ +# -*- encoding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -26,6 +27,7 @@ import tempfile import pickle import functools +import time import datetime import py4j @@ -44,6 +46,22 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.window import Window +from pyspark.sql.utils import AnalysisException + + +class UTC(datetime.tzinfo): + """UTC""" + ZERO = datetime.timedelta(0) + + def utcoffset(self, dt): + return self.ZERO + + def tzname(self, dt): + return "UTC" + + def dst(self, dt): + return self.ZERO class ExamplePointUDT(UserDefinedType): @@ -99,6 +117,15 @@ def test_data_type_eq(self): lt2 = pickle.loads(pickle.dumps(LongType())) self.assertEquals(lt, lt2) + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + class SQLTests(ReusedPySparkTestCase): @@ -117,6 +144,13 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_range(self): + self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) + self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) + self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.sqlCtx.range(-2).count(), 0) + self.assertEqual(self.sqlCtx.range(3).count(), 3) + def test_explode(self): from pyspark.sql.functions import explode d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] @@ -132,6 +166,14 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + def test_and_in_expression(self): + self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) + self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) + self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) + self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") + self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) + self.assertRaises(ValueError, lambda: not self.df.key == 1) + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) @@ -376,7 +418,7 @@ def test_column_operators(self): self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) @@ -476,33 +518,95 @@ def test_between_function(self): self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()) + def test_struct_type(self): + from pyspark.sql.types import StructType, StringType, StructField + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + try: + struct1 = StructType().add("name") + self.assertEqual(1, 0) + except ValueError: + self.assertEqual(1, 1) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) - df.save(tmpPath, "org.apache.spark.sql.json", "error") - actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) - self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + actual = self.sqlCtx.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") - actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.json(tmpPath, "overwrite") + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath, - noUse="this options will not be used in save.") - actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath, - noUse="this options will not be used in load.") - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.save(format="json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.read.load(format="json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") actual = self.sqlCtx.load(path=tmpPath) - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + + def test_save_and_load_builder(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + + df.write.mode("overwrite").json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ + .option("noUse", "this option will not be used in save.")\ + .format("json").save(path=tmpPath) + actual =\ + self.sqlCtx.read.format("json")\ + .load(path=tmpPath, noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) shutil.rmtree(tmpPath) @@ -525,6 +629,14 @@ def test_access_column(self): self.assertRaises(IndexError, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) + def test_column_name_with_non_ascii(self): + df = self.sqlCtx.createDataFrame([(1,)], ["数量"]) + self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema) + self.assertEqual("DataFrame[数量: bigint]", str(df)) + self.assertEqual([("数量", 'bigint')], df.dtypes) + self.assertEqual(1, df.select("数量").first()[0]) + self.assertEqual(1, df.select(df["数量"]).first()[0]) + def test_access_nested_types(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) @@ -571,6 +683,23 @@ def test_filter_with_datetime(self): self.assertEqual(0, df.filter(df.date > date).count()) self.assertEqual(0, df.filter(df.time > time).count()) + def test_time_with_timezone(self): + day = datetime.date.today() + now = datetime.datetime.now() + ts = time.mktime(now.timetuple()) + now.microsecond / 1e6 + # class in __main__ is not serializable + from pyspark.sql.tests import UTC + utc = UTC() + utcnow = datetime.datetime.fromtimestamp(ts, utc) + df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) + day1, now1, utcnow1 = df.first() + # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version + self.assertEqual(day1.date(), day) + # Pyrolite does not support microsecond, the error should be + # less than 1 millisecond + self.assertTrue(now - now1 < datetime.timedelta(0.001)) + self.assertTrue(now - utcnow1 < datetime.timedelta(0.001)) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), @@ -728,6 +857,12 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + def test_capture_analysis_exception(self): + self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) + self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) + # RuntimeException should not be captured + self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + class HiveContextSQLTests(ReusedPySparkTestCase): @@ -738,11 +873,11 @@ def setUpClass(cls): try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.sqlCtx = None - return + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") except TypeError: - cls.sqlCtx = None - return + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) _scala_HiveContext =\ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) @@ -756,57 +891,68 @@ def tearDownClass(cls): shutil.rmtree(cls.tempdir.name, ignore_errors=True) def test_save_and_load_table(self): - if self.sqlCtx is None: - return # no hive available, skipped - df = self.df tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) - df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath) - actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, - "org.apache.spark.sql.json") - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json") + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("DROP TABLE externalJsonTable") - df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath) + df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.createExternalTable("externalJsonTable", - source="org.apache.spark.sql.json", + actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json", schema=schema, path=tmpPath, noUse="this options will not be used") - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) - self.assertTrue( - sorted(df.select("value").collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) - self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) self.sqlCtx.sql("DROP TABLE savedJsonTable") self.sqlCtx.sql("DROP TABLE externalJsonTable") defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("DROP TABLE savedJsonTable") self.sqlCtx.sql("DROP TABLE externalJsonTable") self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) shutil.rmtree(tmpPath) + def test_window_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.partitionBy("value").orderBy("key") + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 1, 1, 1, 1, 1), + ("2", 1, 1, 1, 3, 1, 1, 1, 1), + ("2", 1, 2, 1, 3, 2, 1, 1, 1), + ("2", 2, 2, 2, 3, 3, 3, 2, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py similarity index 86% rename from python/pyspark/sql/_types.py rename to python/pyspark/sql/types.py index 629c3a94513b..160df40d65cc 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/types.py @@ -19,6 +19,7 @@ import decimal import time import datetime +import calendar import keyword import warnings import json @@ -73,56 +74,84 @@ def json(self): # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle -class PrimitiveTypeSingleton(type): - """Metaclass for PrimitiveType""" +class DataTypeSingleton(type): + """Metaclass for DataType""" _instances = {} def __call__(cls): if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + cls._instances[cls] = super(DataTypeSingleton, cls).__call__() return cls._instances[cls] -class PrimitiveType(DataType): - """Spark SQL PrimitiveType""" +class NullType(DataType): + """Null type. - __metaclass__ = PrimitiveTypeSingleton + The data type representing None, used for the types that cannot be inferred. + """ + __metaclass__ = DataTypeSingleton -class NullType(PrimitiveType): - """Null type. - The data type representing None, used for the types that cannot be inferred. +class AtomicType(DataType): + """An internal type used to represent everything that is not + null, UDTs, arrays, structs, and maps.""" + + +class NumericType(AtomicType): + """Numeric data types. + """ + + +class IntegralType(NumericType): + """Integral data types. + """ + + __metaclass__ = DataTypeSingleton + + +class FractionalType(NumericType): + """Fractional data types. """ -class StringType(PrimitiveType): +class StringType(AtomicType): """String data type. """ + __metaclass__ = DataTypeSingleton -class BinaryType(PrimitiveType): + +class BinaryType(AtomicType): """Binary (byte array) data type. """ + __metaclass__ = DataTypeSingleton + -class BooleanType(PrimitiveType): +class BooleanType(AtomicType): """Boolean data type. """ + __metaclass__ = DataTypeSingleton -class DateType(PrimitiveType): + +class DateType(AtomicType): """Date (datetime.date) data type. """ + __metaclass__ = DataTypeSingleton + -class TimestampType(PrimitiveType): +class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. """ + __metaclass__ = DataTypeSingleton -class DecimalType(DataType): + +class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. """ @@ -150,31 +179,35 @@ def __repr__(self): return "DecimalType()" -class DoubleType(PrimitiveType): +class DoubleType(FractionalType): """Double data type, representing double precision floats. """ + __metaclass__ = DataTypeSingleton + -class FloatType(PrimitiveType): +class FloatType(FractionalType): """Float data type, representing single precision floats. """ + __metaclass__ = DataTypeSingleton + -class ByteType(PrimitiveType): +class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte. """ def simpleString(self): return 'tinyint' -class IntegerType(PrimitiveType): +class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer. """ def simpleString(self): return 'int' -class LongType(PrimitiveType): +class LongType(IntegralType): """Long data type, i.e. a signed 64-bit integer. If the values are beyond the range of [-9223372036854775808, 9223372036854775807], @@ -184,7 +217,7 @@ def simpleString(self): return 'bigint' -class ShortType(PrimitiveType): +class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer. """ def simpleString(self): @@ -291,6 +324,8 @@ def __init__(self, name, dataType, nullable=True, metadata=None): False """ assert isinstance(dataType, DataType), "dataType should be DataType" + if not isinstance(name, str): + name = name.encode('utf-8') self.name = name self.dataType = dataType self.nullable = nullable @@ -322,8 +357,7 @@ class StructType(DataType): This is the data type representing a :class:`Row`. """ - - def __init__(self, fields): + def __init__(self, fields=None): """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) @@ -335,8 +369,53 @@ def __init__(self, fields): >>> struct1 == struct2 False """ - assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" - self.fields = fields + if not fields: + self.fields = [] + else: + self.fields = fields + assert all(isinstance(f, StructField) for f in fields),\ + "fields should be a list of StructField" + + def add(self, field, data_type=None, nullable=True, metadata=None): + """ + Construct a StructType by adding new elements to it to define the schema. The method accepts + either: + a) A single parameter which is a StructField object. + b) Between 2 and 4 parameters as (name, data_type, nullable (optional), + metadata(optional). The data_type parameter may be either a String or a DataType object + + >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + >>> struct2 = StructType([StructField("f1", StringType(), True),\ + StructField("f2", StringType(), True, None)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add(StructField("f1", StringType(), True)) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add("f1", "string", True) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + + :param field: Either the name of the field or a StructField object + :param data_type: If present, the DataType of the StructField to create + :param nullable: Whether the field to add should be nullable (default True) + :param metadata: Any additional metadata (default None) + :return: a new updated StructType + """ + if isinstance(field, StructField): + self.fields.append(field) + else: + if isinstance(field, str) and data_type is None: + raise ValueError("Must specify DataType if passing name of struct_field to create.") + + if isinstance(data_type, str): + data_type_f = _parse_datatype_json_value(data_type) + else: + data_type_f = data_type + self.fields.append(StructField(field, data_type_f, nullable, metadata)) + return self def simpleString(self): return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) @@ -426,11 +505,9 @@ def __eq__(self, other): return type(self) == type(other) -_all_primitive_types = dict((v.typeName(), v) - for v in list(globals().values()) - if (type(v) is type or type(v) is PrimitiveTypeSingleton) - and v.__base__ == PrimitiveType) - +_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, + ByteType, ShortType, IntegerType, LongType, DateType, TimestampType] +_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) _all_complex_types = dict((v.typeName(), v) for v in [ArrayType, MapType, StructType]) @@ -444,7 +521,7 @@ def _parse_datatype_json_string(json_string): ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype - >>> for cls in _all_primitive_types.values(): + >>> for cls in _all_atomic_types.values(): ... check_datatype(cls()) >>> # Simple ArrayType. @@ -494,8 +571,8 @@ def _parse_datatype_json_string(json_string): def _parse_datatype_json_value(json_value): if not isinstance(json_value, dict): - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() + if json_value in _all_atomic_types.keys(): + return _all_atomic_types[json_value]() elif json_value == 'decimal': return DecimalType() elif _FIXED_DECIMAL.match(json_value): @@ -604,7 +681,7 @@ def _need_python_to_sql_conversion(dataType): >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), ... StructField("values", ArrayType(DoubleType(), False), False)]) >>> _need_python_to_sql_conversion(schema0) - False + True >>> _need_python_to_sql_conversion(ExamplePointUDT()) True >>> schema1 = ArrayType(ExamplePointUDT(), False) @@ -616,7 +693,8 @@ def _need_python_to_sql_conversion(dataType): True """ if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + # convert namedtuple or Row into tuple + return True elif isinstance(dataType, ArrayType): return _need_python_to_sql_conversion(dataType.elementType) elif isinstance(dataType, MapType): @@ -624,10 +702,15 @@ def _need_python_to_sql_conversion(dataType): _need_python_to_sql_conversion(dataType.valueType) elif isinstance(dataType, UserDefinedType): return True + elif isinstance(dataType, (DateType, TimestampType)): + return True else: return False +EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def _python_to_sql_converter(dataType): """ Returns a converter that converts a Python object into a SQL datum for the given type. @@ -652,31 +735,49 @@ def _python_to_sql_converter(dataType): if isinstance(dataType, StructType): names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) + if any(_need_python_to_sql_conversion(t) for t in types): + converters = [_python_to_sql_converter(t) for t in types] + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): + return tuple(c(v) for c, v in zip(converters, obj)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + elif obj is not None: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + else: + def converter(obj): + if isinstance(obj, dict): + return tuple(obj.get(n) for n in names) else: - return tuple(c(v) for c, v in zip(converters, obj)) - else: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return tuple(obj) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] + return lambda a: a and [element_converter(v) for v in a] elif isinstance(dataType, MapType): key_converter = _python_to_sql_converter(dataType.keyType) value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) + return lambda obj: obj and dataType.serialize(obj) + + elif isinstance(dataType, DateType): + return lambda d: d and d.toordinal() - EPOCH_ORDINAL + + elif isinstance(dataType, TimestampType): + + def to_posix_timstamp(dt): + if dt: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e7 + dt.microsecond * 10) + return to_posix_timstamp + else: raise ValueError("Unexpected type %r" % dataType) @@ -977,10 +1078,13 @@ def _verify_type(obj, dataType): _type = type(dataType) assert _type in _acceptable_types, "unknown datatype: %s" % dataType - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) + if _type is StructType: + if not isinstance(obj, (tuple, list)): + raise TypeError("StructType can not accept object in type %s" % type(obj)) + else: + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): for i in obj: @@ -1125,7 +1229,7 @@ def Dict(d): return lambda datum: dataType.deserialize(datum) elif not isinstance(dataType, StructType): - # no wrapper for primitive types + # no wrapper for atomic types return lambda x: x class Row(tuple): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py new file mode 100644 index 000000000000..cc5b2c088b7c --- /dev/null +++ b/python/pyspark/sql/utils.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import py4j + + +class AnalysisException(Exception): + """ + Failed to analyze a SQL query plan. + """ + + +def capture_sql_exception(f): + def deco(*a, **kw): + try: + return f(*a, **kw) + except py4j.protocol.Py4JJavaError as e: + s = e.java_exception.toString() + if s.startswith('org.apache.spark.sql.AnalysisException: '): + raise AnalysisException(s.split(': ', 1)[1]) + raise + return deco + + +def install_exception_handler(): + """ + Hook an exception handler into Py4j, which could capture some SQL exceptions in Java. + + When calling Java API, it will call `get_return_value` to parse the returned object. + If any exception happened in JVM, the result will be Java exception object, it raise + py4j.protocol.Py4JJavaError. We replace the original `get_return_value` with one that + could capture the Java exception and throw a Python one (with the same error message). + + It's idempotent, could be called multiple times. + """ + original = py4j.protocol.get_return_value + # The original `get_return_value` is not patched, it's idempotent. + patched = capture_sql_exception(original) + # only patch the one used in in py4j.java_gateway (call Java API) + py4j.java_gateway.get_return_value = patched diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py new file mode 100644 index 000000000000..c74745c726a0 --- /dev/null +++ b/python/pyspark/sql/window.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext +from pyspark.sql import since +from pyspark.sql.column import _to_seq, _to_java_column + +__all__ = ["Window", "WindowSpec"] + + +def _to_java_cols(cols): + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return _to_seq(sc, cols, _to_java_column) + + +class Window(object): + """ + Utility functions for defining window in DataFrames. + + For example: + + >>> # PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + >>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxsize, 0) + + >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING + >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + @staticmethod + @since(1.4) + def partitionBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + @staticmethod + @since(1.4) + def orderBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + +class WindowSpec(object): + """ + A window specification that defines the partitioning, ordering, + and frame boundaries. + + Use the static methods in :class:`Window` to create a :class:`WindowSpec`. + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + + _JAVA_MAX_LONG = (1 << 63) - 1 + _JAVA_MIN_LONG = - (1 << 63) + + def __init__(self, jspec): + self._jspec = jspec + + @since(1.4) + def partitionBy(self, *cols): + """ + Defines the partitioning columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols))) + + @since(1.4) + def orderBy(self, *cols): + """ + Defines the ordering columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.orderBy(_to_java_cols(cols))) + + @since(1.4) + def rowsBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rowsBetween(start, end)) + + @since(1.4) + def rangeBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rangeBetween(start, end)) + + +def _test(): + import doctest + SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ff097985fae3..8dcb9645cdc6 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -176,7 +176,7 @@ def takeAndPrint(time, rdd): print(record) if len(taken) > num: print("...") - print() + print("") self.foreachRDD(takeAndPrint) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py new file mode 100644 index 000000000000..cbb573f226bb --- /dev/null +++ b/python/pyspark/streaming/flume.py @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version >= "3": + from io import BytesIO +else: + from StringIO import StringIO +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int +from pyspark.streaming import DStream + +__all__ = ['FlumeUtils', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class FlumeUtils(object): + + @staticmethod + def createStream(ssc, hostname, port, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + enableDecompression=False, + bodyDecoder=utf8_decoder): + """ + Create an input stream that pulls events from Flume. + + :param ssc: StreamingContext object + :param hostname: Hostname of the slave machine to which the flume data will be sent + :param port: Port of the slave machine to which the flume data will be sent + :param storageLevel: Storage level to use for storing the received objects + :param enableDecompression: Should netty server decompress input stream + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def createPollingStream(ssc, addresses, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + maxBatchSize=1000, + parallelism=5, + bodyDecoder=utf8_decoder): + """ + Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + This stream will poll the sink for data and will pull events as they are available. + + :param ssc: StreamingContext object + :param addresses: List of (host, port)s on which the Spark Sink is running. + :param storageLevel: Storage level to use for storing the received objects + :param maxBatchSize: The maximum number of events to be pulled from the Spark sink + in a single RPC call + :param parallelism: Number of concurrent requests this stream should send to the sink. + Note that having a higher number of requests concurrently being pulled + will result in this stream using more threads + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + hosts = [] + ports = [] + for (host, port) in addresses: + hosts.append(host) + ports.append(port) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createPollingStream( + ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def _toPythonDStream(ssc, jstream, bodyDecoder): + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + + def func(event): + headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0]) + headers = {} + strSer = UTF8Deserializer() + for i in range(0, read_int(headersBytes)): + key = strSer.loads(headersBytes) + value = strSer.loads(headersBytes) + headers[key] = value + body = bodyDecoder(event[1]) + return (headers, body) + return stream.map(func) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Flume libraries not found in class path. Try one of the following. + + 1. Include the Flume library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index e278b29003f6..10a859a532e2 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -132,11 +132,12 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, .. note:: Experimental Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object :param kafkaParams: Additional params for Kafka :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty - map, in which case leaders will be looked up on the driver. + map, in which case leaders will be looked up on the driver. :param keyDecoder: A function used to decode key (default is utf8_decoder) :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A RDD object diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33ea8c9293d7..188c8ff12067 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -15,6 +15,7 @@ # limitations under the License. # +import glob import os import sys from itertools import chain @@ -37,12 +38,13 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition +from pyspark.streaming.flume import FlumeUtils class PySparkStreamingTestCase(unittest.TestCase): - timeout = 4 # seconds - duration = .2 + timeout = 10 # seconds + duration = .5 @classmethod def setUpClass(cls): @@ -379,13 +381,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 5 + timeout = 15 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(.6, .2).count() + return dstream.window(1.5, .5).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -394,7 +396,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(.6, .2) + return dstream.countByWindow(1.5, .5) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -403,7 +405,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(1, .2) + return dstream.countByWindow(2.5, .5) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -412,7 +414,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(1, .2) + return dstream.countByValueAndWindow(2.5, .5) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -421,7 +423,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) + return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] @@ -615,7 +617,6 @@ def test_kafka_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, @@ -631,7 +632,6 @@ def test_kafka_direct_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) self._validateStreamResult(sendData, stream) @@ -646,7 +646,6 @@ def test_kafka_direct_stream_from_offset(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) self._validateStreamResult(sendData, stream) @@ -661,7 +660,6 @@ def test_kafka_rdd(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self._validateRddResult(sendData, rdd) @@ -677,9 +675,197 @@ def test_kafka_rdd_with_leaders(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) + +class FlumeStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(FlumeStreamTests, self).setUp() + + utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + super(FlumeStreamTests, self).tearDown() + + def _startContext(self, n, compressed): + # Start the StreamingContext and also collect the result + dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), + enableDecompression=compressed) + result = [] + + def get_output(_, rdd): + for event in rdd.collect(): + if len(result) < n: + result.append(event) + dstream.foreachRDD(get_output) + self.ssc.start() + return result + + def _validateResult(self, input, result): + # Validate both the header and the body + header = {"test": "header"} + self.assertEqual(len(input), len(result)) + for i in range(0, len(input)): + self.assertEqual(header, result[i][0]) + self.assertEqual(input[i], result[i][1]) + + def _writeInput(self, input, compressed): + # Try to write input to the receiver until success or timeout + start_time = time.time() + while True: + try: + self._utils.writeInput(input, compressed) + break + except: + if time.time() - start_time < self.timeout: + time.sleep(0.01) + else: + raise + + def test_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), False) + self._writeInput(input, False) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + def test_compressed_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), True) + self._writeInput(input, True) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + +class FlumePollingStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + maxAttempts = 5 + + def setUp(self): + utilsClz = \ + self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + def _writeAndVerify(self, ports): + # Set up the streaming context and input streams + ssc = StreamingContext(self.sc, self.duration) + try: + addresses = [("localhost", port) for port in ports] + dstream = FlumeUtils.createPollingStream( + ssc, + addresses, + maxBatchSize=self._utils.eventsPerBatch(), + parallelism=5) + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + dstream.foreachRDD(get_output) + ssc.start() + self._utils.sendDatAndEnsureAllDataHasBeenReceived() + + self.wait_for(outputBuffer, self._utils.getTotalEvents()) + outputHeaders = [event[0] for event in outputBuffer] + outputBodies = [event[1] for event in outputBuffer] + self._utils.assertOutput(outputHeaders, outputBodies) + finally: + ssc.stop(False) + + def _testMultipleTimes(self, f): + attempt = 0 + while True: + try: + f() + break + except: + attempt += 1 + if attempt >= self.maxAttempts: + raise + else: + import traceback + traceback.print_exc() + + def _testFlumePolling(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def _testFlumePollingMultipleHosts(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def test_flume_polling(self): + self._testMultipleTimes(self._testFlumePolling) + + def test_flume_polling_multiple_hosts(self): + self._testMultipleTimes(self._testFlumePollingMultipleHosts) + + +def search_kafka_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") + jars = glob.glob( + os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " + "remove all but one") % kafka_assembly_dir) + else: + return jars[0] + + +def search_flume_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") + jars = glob.glob( + os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + if __name__ == "__main__": + kafka_assembly_jar = search_kafka_assembly_jar() + flume_assembly_jar = search_flume_assembly_jar() + jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 34291f30a565..a9bfec2aab8f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -125,4 +125,6 @@ def rddToFileName(prefix, suffix, timestamp): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 09de4d159fdc..17256dfc9574 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -179,9 +179,12 @@ def test_in_memory_sort(self): list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit l = list(range(1024)) random.shuffle(l) - sorter = ExternalSorter(1) + sorter = CustomizedSorter(1) self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled @@ -444,6 +447,11 @@ def func(x): class RDDTests(ReusedPySparkTestCase): + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + def test_id(self): rdd = self.sc.parallelize(range(10)) id = rdd.id() @@ -453,6 +461,14 @@ def test_id(self): self.assertEqual(id + 1, id2) self.assertEqual(id2, rdd2.id()) + def test_empty_rdd(self): + rdd = self.sc.emptyRDD() + self.assertTrue(rdd.isEmpty()) + + def test_sum(self): + self.assertEqual(0, self.sc.emptyRDD().sum()) + self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" @@ -1405,7 +1421,8 @@ def do_termination_test(self, terminator): # start daemon daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) # read the port number port = read_int(daemon.stdout) @@ -1543,13 +1560,13 @@ def count(): def test_with_different_versions_of_python(self): rdd = self.sc.parallelize(range(10)) rdd.count() - version = sys.version_info - sys.version_info = (2, 0, 0) + version = self.sc.pythonVer + self.sc.pythonVer = "2.0" try: with QuietTest(self.sc): self.assertRaises(Py4JJavaError, lambda: rdd.count()) finally: - sys.version_info = version + self.sc.pythonVer = version class SparkSubmitTests(unittest.TestCase): @@ -1804,6 +1821,10 @@ def run(): sc.stop() + def test_startTime(self): + with SparkContext() as sc: + self.assertGreater(sc.startTime, 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fbdaf3a5814c..93df9002be37 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -57,6 +57,12 @@ def main(infile, outfile): if split_index == -1: # for unit tests exit(-1) + version = utf8_deserializer.loads(infile) + if version != "%d.%d" % sys.version_info[:2]: + raise Exception(("Python in worker has different version %s than that in " + + "driver %s, PySpark cannot run with different minor versions") % + ("%d.%d" % sys.version_info[:2], version)) + # initialize global state shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 @@ -92,11 +98,7 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, profiler, deserializer, serializer), version = command - if version != sys.version_info[:2]: - raise Exception(("Python in worker has different version %s than that in " + - "driver %s, PySpark cannot run with different minor versions") % - (sys.version_info[:2], version)) + func, profiler, deserializer, serializer = command init_time = time.time() def process(): diff --git a/python/run-tests b/python/run-tests index f2757a3967e8..24949657ed7a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -18,160 +18,7 @@ # -# Figure out where the Spark framework is installed -FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" -. "$FWDIR"/bin/load-spark-env.sh - -# CD into the python directory to find things on the right path -cd "$FWDIR/python" - -FAILED=0 -LOG_FILE=unit-tests.log -START=$(date +"%s") - -rm -f $LOG_FILE - -# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL -rm -rf metastore warehouse - -function run_test() { - echo -en "Running test: $1 ... " | tee -a $LOG_FILE - start=$(date +"%s") - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 - - FAILED=$((PIPESTATUS[0]||$FAILED)) - - # Fail and exit on the first test failure. - if [[ $FAILED != 0 ]]; then - cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. - echo -en "\033[31m" # Red - echo "Had test failures; see logs." - echo -en "\033[0m" # No color - exit -1 - else - now=$(date +"%s") - echo "ok ($(($now - $start))s)" - fi -} - -function run_core_tests() { - echo "Run core tests ..." - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" -} - -function run_sql_tests() { - echo "Run sql tests ..." - run_test "pyspark/sql/_types.py" - run_test "pyspark/sql/context.py" - run_test "pyspark/sql/dataframe.py" - run_test "pyspark/sql/functions.py" - run_test "pyspark/sql/tests.py" -} - -function run_mllib_tests() { - echo "Run mllib tests ..." - run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/evaluation.py" - run_test "pyspark/mllib/feature.py" - run_test "pyspark/mllib/fpm.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/rand.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat/_statistics.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/tests.py" -} - -function run_ml_tests() { - echo "Run ml tests ..." - run_test "pyspark/ml/feature.py" - run_test "pyspark/ml/classification.py" - run_test "pyspark/ml/recommendation.py" - run_test "pyspark/ml/regression.py" - run_test "pyspark/ml/tuning.py" - run_test "pyspark/ml/tests.py" - run_test "pyspark/ml/evaluation.py" -} - -function run_streaming_tests() { - echo "Run streaming tests ..." - - KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly - JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" - for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do - if [[ ! -e "$f" ]]; then - echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 - echo "You need to build Spark with " \ - "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ - "'build/mvn package' before running this program" 1>&2 - exit 1 - fi - KAFKA_ASSEMBLY_JAR="$f" - done - - export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" -} - -echo "Running PySpark tests. Output is in python/$LOG_FILE." - -export PYSPARK_PYTHON="python" - -# Try to test with Python 2.6, since that's the minimum version that we support: -if [ $(which python2.6) ]; then - export PYSPARK_PYTHON="python2.6" -fi - -echo "Testing with Python version:" -$PYSPARK_PYTHON --version - -run_core_tests -run_sql_tests -run_mllib_tests -run_ml_tests -run_streaming_tests - -# Try to test with Python 3 -if [ $(which python3.4) ]; then - export PYSPARK_PYTHON="python3.4" - echo "Testing with Python3.4 version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_mllib_tests - run_ml_tests - run_streaming_tests -fi - -# Try to test with PyPy -if [ $(which pypy) ]; then - export PYSPARK_PYTHON="pypy" - echo "Testing with PyPy version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_streaming_tests -fi - -if [[ $FAILED == 0 ]]; then - now=$(date +"%s") - echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds" -fi - -# TODO: in the long-run, it would be nice to use a test runner like `nose`. -# The doctest fixtures are the current barrier to doing this. +exec python -u ./python/run-tests.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py new file mode 100755 index 000000000000..7638854def2e --- /dev/null +++ b/python/run-tests.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +import logging +from optparse import OptionParser +import os +import re +import subprocess +import sys +import tempfile +from threading import Thread, Lock +import time +if sys.version < '3': + import Queue +else: + import queue as Queue +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output + + +# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/")) + + +from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) +from sparktestsupport.shellutils import which # noqa +from sparktestsupport.modules import all_modules # noqa + + +python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') + + +def print_red(text): + print('\033[31m' + text + '\033[0m') + + +LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") +FAILURE_REPORTING_LOCK = Lock() +LOGGER = logging.getLogger() + + +def run_individual_python_test(test_name, pyspark_python): + env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) + start_time = time.time() + try: + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() + except: + LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(1) + duration = time.time() - start_time + # Exit on the first failure. + if retcode != 0: + try: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output) + per_test_output.seek(0) + for line in per_test_output: + decoded_line = line.decode() + if not re.match('[0-9]+', decoded_line): + print(decoded_line, end='') + per_test_output.close() + except: + LOGGER.exception("Got an exception while trying to print failed test output") + finally: + print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + else: + per_test_output.close() + LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + + +def get_default_python_executables(): + python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] + if "python2.6" not in python_execs: + LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") + python_execs.insert(0, "python") + return python_execs + + +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "--python-executables", type="string", default=','.join(get_default_python_executables()), + help="A comma-separated list of Python executables to test against (default: %default)" + ) + parser.add_option( + "--modules", type="string", + default=",".join(sorted(python_modules.keys())), + help="A comma-separated list of Python modules to test (default: %default)" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + parser.add_option( + "--verbose", action="store_true", + help="Enable additional debug logging" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + +def main(): + opts = parse_opts() + if (opts.verbose): + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") + LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) + if os.path.exists(LOG_FILE): + os.remove(LOG_FILE) + python_execs = opts.python_executables.split(',') + modules_to_test = [] + for module_name in opts.modules.split(','): + if module_name in python_modules: + modules_to_test.append(python_modules[module_name]) + else: + print("Error: unrecognized module %s" % module_name) + sys.exit(-1) + LOGGER.info("Will test against the following Python executables: %s", python_execs) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) + + task_queue = Queue.Queue() + for python_exec in python_execs: + python_implementation = subprocess_check_output( + [python_exec, "-c", "import platform; print(platform.python_implementation())"], + universal_newlines=True).strip() + LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) + LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( + [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) + for module in modules_to_test: + if python_implementation not in module.blacklisted_python_implementations: + for test_goal in module.python_test_goals: + task_queue.put((python_exec, test_goal)) + + def process_queue(task_queue): + while True: + try: + (python_exec, test_goal) = task_queue.get_nowait() + except Queue.Empty: + break + try: + run_individual_python_test(test_goal, python_exec) + finally: + task_queue.task_done() + + start_time = time.time() + for _ in range(opts.parallelism): + worker = Thread(target=process_queue, args=(task_queue,)) + worker.daemon = True + worker.start() + try: + task_queue.join() + except (KeyboardInterrupt, SystemExit): + print_red("Exiting due to interrupt") + sys.exit(-1) + total_duration = time.time() - start_time + LOGGER.info("Tests passed in %i seconds", total_duration) + + +if __name__ == "__main__": + main() diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata new file mode 100644 index 000000000000..7ef2320651de Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_common_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata new file mode 100644 index 000000000000..78a1ca7d3827 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc new file mode 100644 index 000000000000..e93f42ed6f35 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet new file mode 100644 index 000000000000..461c382937ec Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc new file mode 100644 index 000000000000..b63c4d6d1e1d Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc new file mode 100644 index 000000000000..5bc0ebd71356 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet new file mode 100644 index 000000000000..62a63915beac Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet new file mode 100644 index 000000000000..67665a7b55da Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc new file mode 100644 index 000000000000..ae94a15d08c8 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet new file mode 100644 index 000000000000..6cb8538aa890 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc new file mode 100644 index 000000000000..58d9bb5fc588 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet new file mode 100644 index 000000000000..9b00805481e7 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet differ diff --git a/python/test_support/sql/people.json b/python/test_support/sql/people.json new file mode 100644 index 000000000000..50a859cbd7ee --- /dev/null +++ b/python/test_support/sql/people.json @@ -0,0 +1,3 @@ +{"name":"Michael"} +{"name":"Andy", "age":30} +{"name":"Justin", "age":19} diff --git a/repl/pom.xml b/repl/pom.xml index 03053b4c3b28..370b2bc2fa8e 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -48,6 +48,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-bagel_${scala.binary.version} @@ -86,7 +93,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 934daaeaafca..f150fec7db94 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,13 +22,12 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" @@ -268,6 +267,17 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } + test("SPARK-8461 SQL with codegen") { + val output = runInterpreter("local", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |sqlContext.setConf("spark.sql.codegen", "true") + |sqlContext.range(0, 100).filter('id > 50).count() + """.stripMargin) + assertContains("Long = 49", output) + assertDoesNotContain("java.lang.ClassNotFoundException", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 14f5e9ed4f25..9ecc7c229e38 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.tools.nsc.interpreter.SparkILoop -import org.scalatest.FunSuite import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c709cde74074..a58eda12b112 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -25,7 +25,6 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar @@ -35,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.util.Utils class ExecutorClassLoaderSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfterAll with MockitoSugar with Logging { diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7168d5b2a8e2..d6f927b6fa80 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -14,25 +14,41 @@ ~ See the License for the specific language governing permissions and ~ limitations under the License. --> - - - - - - + - Scalastyle standard configuration - - - - - - - - - Scalastyle standard configuration + + + + + + + + + + - - - - - - - - - - true - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + + + + + ^println$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + diff --git a/sql/README.md b/sql/README.md index 46aec7cef798..63d4dac9829e 100644 --- a/sql/README.md +++ b/sql/README.md @@ -25,7 +25,7 @@ export HADOOP_HOME="/hadoop-1.0.4" If you are working with Hive 0.13.1, the following steps are needed: -1. Download Hive's [0.13.1](https://hive.apache.org/downloads.html) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). +1. Download Hive's [0.13.1](https://archive.apache.org/dist/hive/hive-0.13.1) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). 2. Set `HADOOP_HOME` with `export HADOOP_HOME=""` 3. Download all Hive 0.13.1a jars (Hive jars actually used by Spark) from [here](http://mvnrepository.com/artifact/org.spark-project.hive) and replace corresponding original 0.13.1 jars in `$HIVE_HOME/lib`. 4. Download [Kryo 2.21 jar](http://mvnrepository.com/artifact/com.esotericsoftware.kryo/kryo/2.21) (Note: 2.22 jar does not work) and [Javolution 5.5.1 jar](http://mvnrepository.com/artifact/javolution/javolution/5.5.1) to `$HIVE_HOME/lib`. diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5c322d032d47..f4b1cc3a4ffe 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -36,10 +36,6 @@ - - org.scala-lang - scala-compiler - org.scala-lang scala-reflect @@ -50,6 +46,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-unsafe_${scala.binary.version} @@ -60,6 +63,11 @@ scalacheck_${scala.binary.version} test + + org.codehaus.janino + janino + 2.7.8 + target/scala-${scala.binary.version}/classes @@ -101,13 +109,6 @@ !scala-2.11 - - - org.scalamacros - quasiquotes_${scala.binary.version} - ${scala.macros.version} - - diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 299ff3728a6d..1e79f4b2e88e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions; -import java.util.Arrays; import java.util.Iterator; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import scala.Function1; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.catalyst.util.UniqueObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -39,16 +40,28 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. */ - private final long[] emptyAggregationBuffer; + private final byte[] emptyBuffer; + + /** + * An empty row used by `initProjection` + */ + private static final InternalRow emptyRow = new GenericInternalRow(); - private final StructType aggregationBufferSchema; + /** + * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. + */ + private final boolean reuseEmptyBuffer; - private final StructType groupingKeySchema; + /** + * The projection used to initialize the emptyBuffer + */ + private final Function1 initProjection; /** - * Encodes grouping keys as UnsafeRows. + * Encodes grouping keys or buffers as UnsafeRows. */ - private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + private final UnsafeRowConverter keyConverter; + private final UnsafeRowConverter bufferConverter; /** * A hashmap which maps from opaque bytearray keys to bytearray values. @@ -56,134 +69,115 @@ public final class UnsafeFixedWidthAggregationMap { private final BytesToBytesMap map; /** - * Re-used pointer to the current aggregation buffer + * An object pool for objects that are used in grouping keys. */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UniqueObjectPool keyPool; /** - * Scratch space that is used when encoding grouping keys into UnsafeRow format. - * - * By default, this is a 1MB array, but it will grow as necessary in case larger keys are - * encountered. + * An object pool for objects that are used in aggregation buffers. */ - private long[] groupingKeyConversionScratchSpace = new long[1024 / 8]; - - private final boolean enablePerfMetrics; + private final ObjectPool bufferPool; /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. + * Re-used pointer to the current aggregation buffer */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } + private final UnsafeRow currentBuffer = new UnsafeRow(); /** - * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given - * schema, false otherwise. + * Scratch space that is used when encoding grouping keys into UnsafeRow format. + * + * By default, this is a 8 kb array, but it will grow as necessary in case larger keys are + * encountered. */ - public static boolean supportsAggregationBufferSchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } + private byte[] groupingKeyConversionScratchSpace = new byte[1024 * 8]; + + private final boolean enablePerfMetrics; /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) - * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. - * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param initProjection the default value for new keys (a "zero" of the agg. function) + * @param keyConverter the converter of the grouping key, used for row conversion. + * @param bufferConverter the converter of the aggregation buffer, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - Row emptyAggregationBuffer, - StructType aggregationBufferSchema, - StructType groupingKeySchema, + Function1 initProjection, + UnsafeRowConverter keyConverter, + UnsafeRowConverter bufferConverter, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.emptyAggregationBuffer = - convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); - this.aggregationBufferSchema = aggregationBufferSchema; - this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); - this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.initProjection = initProjection; + this.keyConverter = keyConverter; + this.bufferConverter = bufferConverter; this.enablePerfMetrics = enablePerfMetrics; - } - /** - * Convert a Java object row into an UnsafeRow, allocating it into a new long array. - */ - private static long[] convertToUnsafeRow(Row javaRow, StructType schema) { - final UnsafeRowConverter converter = new UnsafeRowConverter(schema); - final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)]; - final long writtenLength = - converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET); - assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; - return unsafeRow; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.keyPool = new UniqueObjectPool(100); + this.bufferPool = new ObjectPool(initialCapacity); + + InternalRow initRow = initProjection.apply(emptyRow); + this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int writtenLength = bufferConverter.writeRow( + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; + // re-use the empty buffer only when there is no object saved in pool. + reuseEmptyBuffer = bufferPool.size() == 0; } /** * Return the aggregation buffer for the current group. For efficiency, all calls to this method * return the same object. */ - public UnsafeRow getAggregationBuffer(Row groupingKey) { - final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { + final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { - // This new array will be initially zero, so there's no need to zero it out here - groupingKeyConversionScratchSpace = new long[groupingKeySize]; - } else { - // Zero out the buffer that's used to hold the current row. This is necessary in order - // to ensure that rows hash properly, since garbage data from the previous row could - // otherwise end up as padding in this row. As a performance optimization, we only zero out - // the portion of the buffer that we'll actually write to. - Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0); + groupingKeyConversionScratchSpace = new byte[groupingKeySize]; } - final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + final int actualGroupingKeySize = keyConverter.writeRow( groupingKey, groupingKeyConversionScratchSpace, - PlatformDependent.LONG_ARRAY_OFFSET); + PlatformDependent.BYTE_ARRAY_OFFSET, + keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( groupingKeyConversionScratchSpace, - PlatformDependent.LONG_ARRAY_OFFSET, + PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize); if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: + if (!reuseEmptyBuffer) { + // There is some objects referenced by emptyBuffer, so generate a new one + InternalRow initRow = initProjection.apply(emptyRow); + bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, + bufferPool); + } loc.putNewKey( groupingKeyConversionScratchSpace, - PlatformDependent.LONG_ARRAY_OFFSET, + PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, - emptyAggregationBuffer, - PlatformDependent.LONG_ARRAY_OFFSET, - emptyAggregationBuffer.length + emptyBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + emptyBuffer.length ); } // Reset the pointer to point to the value that we just stored or looked up: final MemoryLocation address = loc.getValueAddress(); - currentAggregationBuffer.pointTo( + currentBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); - return currentAggregationBuffer; + return currentBuffer; } /** @@ -219,14 +213,14 @@ public MapEntry next() { entry.key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - groupingKeySchema.length(), - groupingKeySchema + keyConverter.numFields(), + keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); return entry; } @@ -254,6 +248,8 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + System.out.println("Number of unique objects in keys: " + keyPool.size()); + System.out.println("Number of objects in buffers: " + bufferPool.size()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb546b3086b3..f077064a02ec 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,22 +17,12 @@ package org.apache.spark.sql.catalyst.expressions; -import scala.collection.Map; -import scala.collection.Seq; -import scala.collection.mutable.ArraySeq; - -import javax.annotation.Nullable; -import java.math.BigDecimal; -import java.sql.Date; -import java.util.*; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.types.DataType; -import static org.apache.spark.sql.types.DataTypes.*; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.UTF8String; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.types.UTF8String; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -45,32 +35,43 @@ * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the - * base address of the row) that points to the beginning of the variable-length field. + * base address of the row) that points to the beginning of the variable-length field, and length + * (they are combined into a long). For other objects, they are stored in a pool, the indexes of + * them are hold in the the word. + * + * In order to support fast hashing and equality checks for UnsafeRows that contain objects + * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make + * sure all the key have the same index for same object, then we can hash/compare the objects by + * hash/compare the index. + * + * For non-primitive types, the word of a field could be: + * UNION { + * [1] [offset: 31bits] [length: 31bits] // StringType + * [0] [offset: 31bits] [length: 31bits] // BinaryType + * - [index: 63bits] // StringType, Binary, index to object in pool + * } * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow implements MutableRow { +public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; + /** A pool to hold non-primitive objects */ + private ObjectPool pool; + Object getBaseObject() { return baseObject; } long getBaseOffset() { return baseOffset; } + ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; + public int length() { return numFields; } + /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; - /** - * This optional schema is required if you want to call generic get() and set() methods on - * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() - * methods. This should be removed after the planned InternalRow / Row split; right now, it's only - * needed by the generic get() method, which is only called internally by code that accesses - * UTF8String-typed columns. - */ - @Nullable - private StructType schema; private long getFieldOffset(int ordinal) { return baseOffset + bitSetWidthInBytes + ordinal * 8L; @@ -80,38 +81,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; } - /** - * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) - */ - public static final Set settableFieldTypes; - - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). - */ - public static final Set readableFieldTypes; - - static { - settableFieldTypes = Collections.unmodifiableSet( - new HashSet( - Arrays.asList(new DataType[] { - NullType, - BooleanType, - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType - }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet( - Arrays.asList(new DataType[]{ - StringType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); - } + public static final long OFFSET_BITS = 31L; /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, @@ -125,22 +95,15 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row - * @param schema an optional schema; this is necessary if you want to call generic get() or set() - * methods on this row, but is optional if the caller will only use type-specific - * getTYPE() and setTYPE() methods. + * @param pool the object pool to hold arbitrary objects */ - public void pointTo( - Object baseObject, - long baseOffset, - int numFields, - @Nullable StructType schema) { + public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; - assert schema == null || schema.fields().length == numFields; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; - this.schema = schema; + this.pool = pool; } private void assertIndexIsValid(int index) { @@ -163,9 +126,68 @@ private void setNotNullAt(int i) { BitSetMethods.unset(baseObject, baseOffset, i); } + /** + * Updates the column `i` as Object `value`, which cannot be primitive types. + */ @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); + public void update(int i, Object value) { + if (value == null) { + if (!isNullAt(i)) { + // remove the old value from pool + long idx = getLong(i); + if (idx <= 0) { + // this is the index of old value in pool, remove it + pool.replace((int)-idx, null); + } else { + // there will be some garbage left (UTF8String or byte[]) + } + setNullAt(i); + } + return; + } + + if (isNullAt(i)) { + // there is not an old value, put the new value into pool + int idx = pool.put(value); + setLong(i, (long)-idx); + } else { + // there is an old value, check the type, then replace it or update it + long v = getLong(i); + if (v <= 0) { + // it's the index in the pool, replace old value with new one + int idx = (int)-v; + pool.replace(idx, value); + } else { + // old value is UTF8String or byte[], try to reuse the space + boolean isString; + byte[] newBytes; + if (value instanceof UTF8String) { + newBytes = ((UTF8String) value).getBytes(); + isString = true; + } else { + newBytes = (byte[]) value; + isString = false; + } + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int oldLength = (int) (v & Integer.MAX_VALUE); + if (newBytes.length <= oldLength) { + // the new value can fit in the old buffer, re-use it + PlatformDependent.copyMemory( + newBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + offset, + newBytes.length); + long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L; + setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length); + } else { + // Cannot fit in the buffer + int idx = pool.put(value); + setLong(i, (long) -idx); + } + } + } + setNotNullAt(i); } @Override @@ -217,46 +239,43 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } - @Override public int size() { return numFields; } - @Override - public int length() { - return size(); - } - - @Override - public StructType schema() { - return schema; - } - - @Override - public Object apply(int i) { - return get(i); - } - + /** + * Returns the object for column `i`, which should not be primitive type. + */ @Override public Object get(int i) { assertIndexIsValid(i); - assert (schema != null) : "Schema must be defined when calling generic get() method"; - final DataType dataType = schema.fields()[i].dataType(); - // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic - // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to - // separate the internal and external row interfaces, then internal code can fetch strings via - // a new getUTF8String() method and we'll be able to remove this method. if (isNullAt(i)) { return null; - } else if (dataType == StringType) { - return getUTF8String(i); + } + long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + if (v <= 0) { + // It's an index to object in the pool. + int idx = (int)-v; + return pool.get(idx); } else { - throw new UnsupportedOperationException(); + // The column could be StingType or BinaryType + boolean isString = (v >> (OFFSET_BITS * 2)) > 0; + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int size = (int) (v & Integer.MAX_VALUE); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size + ); + if (isString) { + return UTF8String.fromBytes(bytes); + } else { + return bytes; + } } } @@ -316,86 +335,8 @@ public double getDouble(int i) { } } - public UTF8String getUTF8String(int i) { - assertIndexIsValid(i); - final UTF8String str = new UTF8String(); - final long offsetToStringSize = getLong(i); - final int stringSizeInBytes = - (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); - final byte[] strBytes = new byte[stringSizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - stringSizeInBytes - ); - str.set(strBytes); - return str; - } - - @Override - public String getString(int i) { - return getUTF8String(i).toString(); - } - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } - - @Override - public Row copy() { + public InternalRow copy() { throw new UnsupportedOperationException(); } @@ -403,33 +344,4 @@ public Row copy() { public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); } - - @Override - public Seq toSeq() { - final ArraySeq values = new ArraySeq(numFields); - for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) { - values.update(fieldNumber, get(fieldNumber)); - } - return values; - } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java new file mode 100644 index 000000000000..97f89a7d0b75 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +/** + * A object pool stores a collection of objects in array, then they can be referenced by the + * pool plus an index. + */ +public class ObjectPool { + + /** + * An array to hold objects, which will grow as needed. + */ + private Object[] objects; + + /** + * How many objects in the pool. + */ + private int numObj; + + public ObjectPool(int capacity) { + objects = new Object[capacity]; + numObj = 0; + } + + /** + * Returns how many objects in the pool. + */ + public int size() { + return numObj; + } + + /** + * Returns the object at position `idx` in the array. + */ + public Object get(int idx) { + assert (idx < numObj); + return objects[idx]; + } + + /** + * Puts an object `obj` at the end of array, returns the index of it. + *

    + * The array will grow as needed. + */ + public int put(Object obj) { + if (numObj >= objects.length) { + Object[] tmp = new Object[objects.length * 2]; + System.arraycopy(objects, 0, tmp, 0, objects.length); + objects = tmp; + } + objects[numObj++] = obj; + return numObj - 1; + } + + /** + * Replaces the object at `idx` with new one `obj`. + */ + public void replace(int idx, Object obj) { + assert (idx < numObj); + objects[idx] = obj; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java new file mode 100644 index 000000000000..d512392dcaac --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +import java.util.HashMap; + +/** + * An unique object pool stores a collection of unique objects in it. + */ +public class UniqueObjectPool extends ObjectPool { + + /** + * A hash map from objects to their indexes in the array. + */ + private HashMap objIndex; + + public UniqueObjectPool(int capacity) { + super(capacity); + objIndex = new HashMap(); + } + + /** + * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will + * return the index of the existing one. + */ + @Override + public int put(Object obj) { + if (objIndex.containsKey(obj)) { + return objIndex.get(obj); + } else { + int idx = super.put(obj); + objIndex.put(obj, idx); + return idx; + } + } + + /** + * The objects can not be replaced. + */ + @Override + public void replace(int idx, Object obj) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 4190b7ffe1c8..0f2fd6a86d17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import scala.util.hashing.MurmurHash3 - import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -55,6 +53,9 @@ object Row { // TODO: Improve the performance of this if used in performance critical part. new GenericRow(rows.flatMap(_.toSeq).toArray) } + + /** Returns an empty row. */ + val empty = apply() } @@ -178,7 +179,7 @@ trait Row extends Serializable { def get(i: Int): Any = apply(i) /** Checks whether the value at position i is null. */ - def isNullAt(i: Int): Boolean + def isNullAt(i: Int): Boolean = apply(i) == null /** * Returns the value at position i as a primitive boolean. @@ -186,7 +187,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean + def getBoolean(i: Int): Boolean = getAs[Boolean](i) /** * Returns the value at position i as a primitive byte. @@ -194,7 +195,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getByte(i: Int): Byte + def getByte(i: Int): Byte = getAs[Byte](i) /** * Returns the value at position i as a primitive short. @@ -202,7 +203,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getShort(i: Int): Short + def getShort(i: Int): Short = getAs[Short](i) /** * Returns the value at position i as a primitive int. @@ -210,7 +211,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int + def getInt(i: Int): Int = getAs[Int](i) /** * Returns the value at position i as a primitive long. @@ -218,7 +219,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long + def getLong(i: Int): Long = getAs[Long](i) /** * Returns the value at position i as a primitive float. @@ -227,7 +228,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float + def getFloat(i: Int): Float = getAs[Float](i) /** * Returns the value at position i as a primitive double. @@ -235,7 +236,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double + def getDouble(i: Int): Double = getAs[Double](i) /** * Returns the value at position i as a String object. @@ -243,29 +244,35 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getString(i: Int): String + def getString(i: Int): String = getAs[String](i) /** * Returns the value at position i of decimal type as java.math.BigDecimal. * * @throws ClassCastException when data type does not match. */ - def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal] + def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i) /** * Returns the value at position i of date type as java.sql.Date. * * @throws ClassCastException when data type does not match. */ - // TODO(davies): This is not the right default implementation, we use Int as Date internally - def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] + def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i) + + /** + * Returns the value at position i of date type as java.sql.Timestamp. + * + * @throws ClassCastException when data type does not match. + */ + def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i) /** * Returns the value at position i of array type as a Scala Seq. * * @throws ClassCastException when data type does not match. */ - def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]] + def getSeq[T](i: Int): Seq[T] = getAs[Seq[T]](i) /** * Returns the value at position i of array type as [[java.util.List]]. @@ -281,7 +288,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]] + def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i) /** * Returns the value at position i of array type as a [[java.util.Map]]. @@ -356,42 +363,21 @@ trait Row extends Serializable { false } - override def equals(that: Any): Boolean = that match { - case null => false - case that: Row => - if (this.length != that.length) { - return false - } - var i = 0 - val len = this.length - while (i < len) { - if (apply(i) != that.apply(i)) { - return false - } - i += 1 - } - true - case _ => false - } - - override def hashCode: Int = { - // Using Scala's Seq hash code implementation. - var n = 0 - var h = MurmurHash3.seqSeed - val len = length - while (n < len) { - h = MurmurHash3.mix(h, apply(n).##) - n += 1 - } - MurmurHash3.finalizeHash(h, n) - } - /* ---------------------- utility methods for Scala ---------------------- */ /** - * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq. + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. */ - def toSeq: Seq[Any] + def toSeq: Seq[Any] = { + val n = length + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, get(i)) + i += 1 + } + values.toSeq + } /** Displays all elements of this sequence in a string (without a separator). */ def mkString: String = toSeq.mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 2eb3e167baad..d494ae7b71d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions import scala.util.parsing.combinator.lexical.StdLexical import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.{PackratParsers, RegexParsers} +import scala.util.parsing.combinator.PackratParsers import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ @@ -30,12 +30,14 @@ private[sql] abstract class AbstractSparkSQLParser def parse(input: String): LogicalPlan = { // Initialize the Keywords. - lexical.initialize(reservedWords) + initLexical phrase(start)(new lexical.Scanner(input)) match { case Success(plan, _) => plan case failureOrError => sys.error(failureOrError.toString) } } + /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */ + protected lazy val initLexical: Unit = lexical.initialize(reservedWords) protected case class Keyword(str: String) { def normalize: String = lexical.normalizeKeyword(str) @@ -103,7 +105,7 @@ class SqlLexical extends StdLexical { ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { - case i ~ None => NumericLit(i.mkString) + case i ~ None => NumericLit(i.mkString) case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) } | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 75a493b248f6..8f63d2120ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -18,13 +18,18 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} +import javax.annotation.Nullable import scala.collection.mutable.HashMap +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Functions to convert Scala types to Catalyst types and vice versa. @@ -34,320 +39,379 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + private def isPrimitive(dataType: DataType): Boolean = { + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case _ => false + } + } + + private def isWholePrimitive(dt: DataType): Boolean = dt match { + case dt if isPrimitive(dt) => true + case ArrayType(elementType, _) => isWholePrimitive(elementType) + case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) + case _ => false + } + + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { + val converter = dataType match { + case udt: UserDefinedType[_] => UDTConverter(udt) + case arrayType: ArrayType => ArrayConverter(arrayType.elementType) + case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) + case structType: StructType => StructConverter(structType) + case StringType => StringConverter + case DateType => DateConverter + case TimestampType => TimestampConverter + case dt: DecimalType => BigDecimalConverter + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType => IntConverter + case LongType => LongConverter + case FloatType => FloatConverter + case DoubleType => DoubleConverter + case _ => IdentityConverter + } + converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] + } + /** - * Converts Scala objects to catalyst rows / types. This method is slow, and for batch - * conversion you should be using converter produced by createToCatalystConverter. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. + * Converts a Scala type to its Catalyst equivalent (and vice versa). + * + * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst. + * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala. + * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type. */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => - udt.serialize(obj) - - case (o: Option[_], _) => - o.map(convertToCatalyst(_, dataType)).orNull - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToCatalyst(_, arrayType.elementType)) - - case (jit: JavaIterable[_], arrayType: ArrayType) => { - val iter = jit.iterator - var listOfItems: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - listOfItems :+= convertToCatalyst(item, arrayType.elementType) + private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType] + extends Serializable { + + /** + * Converts a Scala type to its Catalyst equivalent while automatically handling nulls + * and Options. + */ + final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { + if (maybeScalaValue == null) { + null.asInstanceOf[CatalystType] + } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) { + val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]] + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + } else { + toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType]) } - listOfItems } - case (s: Array[_], arrayType: ArrayType) => - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + */ + final def toScala(row: InternalRow, column: Int): ScalaOutputType = { + if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) + } - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } + /** + * Convert a Catalyst value to its Scala equivalent. + */ + def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType + + /** + * Converts a Scala value to its Catalyst equivalent. + * @param scalaValue the Scala value, guaranteed not to be null. + * @return the Catalyst value. + */ + protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + * This method will only be called on non-null columns. + */ + protected def toScalaImpl(row: InternalRow, column: Int): ScalaOutputType + } - case (jmap: JavaMap[_, _], mapType: MapType) => - val iter = jmap.entrySet.iterator - var listOfEntries: List[(Any, Any)] = List() - while (iter.hasNext) { - val entry = iter.next() - listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), - convertToCatalyst(entry.getValue, mapType.valueType)) - } - listOfEntries.toMap - - case (p: Product, structType: StructType) => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) - idx += 1 - } - new GenericRowWithSchema(ar, structType) + private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toScalaImpl(row: InternalRow, column: Int): Any = row(column) + } - case (d: String, _) => - UTF8String(d) + private case class UDTConverter( + udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row(column)) + } - case (d: BigDecimal, _) => - Decimal(d) + /** Converter for arrays, sequences, and Java iterables. */ + private case class ArrayConverter( + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { - case (d: java.math.BigDecimal, _) => - Decimal(d) + private[this] val elementConverter = getConverterForType(elementType) - case (d: java.sql.Date, _) => - DateUtils.fromJavaDate(d) + private[this] val isNoChange = isWholePrimitive(elementType) - case (r: Row, structType: StructType) => - val converters = structType.fields.map { - f => (item: Any) => convertToCatalyst(item, f.dataType) + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + scalaValue match { + case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) + case s: Seq[_] => s.map(elementConverter.toCatalyst) + case i: JavaIterable[_] => + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter.toCatalyst(item) + } + convertedIterable } - convertRowWithConverters(r, structType, converters) + } - case (other, _) => - other + override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + if (catalystValue == null) { + null + } else if (isNoChange) { + catalystValue + } else { + catalystValue.map(elementConverter.toScala) + } + } + + override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = + toScala(row(column).asInstanceOf[Seq[Any]]) } - /** - * Creates a converter function that will convert Scala objects to the specified catalyst type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { - def extractOption(item: Any): Any = item match { - case opt: Option[_] => opt.orNull - case other => other - } + private case class MapConverter( + keyType: DataType, + valueType: DataType) + extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] { - dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item) => extractOption(item) match { - case null => null - case other => udt.serialize(other) + private[this] val keyConverter = getConverterForType(keyType) + private[this] val valueConverter = getConverterForType(valueType) + + private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType) + + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v) } - case arrayType: ArrayType => - val elementConverter = createToCatalystConverter(arrayType.elementType) - (item: Any) => { - extractOption(item) match { - case a: Array[_] => a.toSeq.map(elementConverter) - case s: Seq[_] => s.map(elementConverter) - case i: JavaIterable[_] => { - val iter = i.iterator - var convertedIterable: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - convertedIterable :+= elementConverter(item) - } - convertedIterable - } - case null => null - } + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + val key = keyConverter.toCatalyst(entry.getKey) + convertedMap(key) = valueConverter.toCatalyst(entry.getValue) } + convertedMap + } - case mapType: MapType => - val keyConverter = createToCatalystConverter(mapType.keyType) - val valueConverter = createToCatalystConverter(mapType.valueType) - (item: Any) => { - extractOption(item) match { - case m: Map[_, _] => - m.map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - - case jmap: JavaMap[_, _] => - val iter = jmap.entrySet.iterator - val convertedMap: HashMap[Any, Any] = HashMap() - while (iter.hasNext) { - val entry = iter.next() - convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) - } - convertedMap - - case null => null - } + override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { + if (catalystValue == null) { + null + } else if (isNoChange) { + catalystValue + } else { + catalystValue.map { case (k, v) => + keyConverter.toScala(k) -> valueConverter.toScala(v) } + } + } - case structType: StructType => - val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) - (item: Any) => { - extractOption(item) match { - case r: Row => - convertRowWithConverters(r, structType, converters) - - case p: Product => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = converters(idx)(iter.next()) - idx += 1 - } - new GenericRowWithSchema(ar, structType) - - case null => - null - } + override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = + toScala(row(column).asInstanceOf[Map[Any, Any]]) + } + + private case class StructConverter( + structType: StructType) extends CatalystTypeConverter[Any, Row, InternalRow] { + + private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } + + override def toCatalystImpl(scalaValue: Any): InternalRow = scalaValue match { + case row: Row => + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toCatalyst(row(idx)) + idx += 1 + } + new GenericInternalRow(ar) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(iter.next()) + idx += 1 } + new GenericInternalRow(ar) + } - case dateType: DateType => (item: Any) => extractOption(item) match { - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case other => other + override def toScala(row: InternalRow): Row = { + if (row == null) { + null + } else { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toScala(row, idx) + idx += 1 + } + new GenericRowWithSchema(ar, structType) } + } - case dataType: StringType => (item: Any) => extractOption(item) match { - case s: String => UTF8String(s) - case other => other - } + override def toScalaImpl(row: InternalRow, column: Int): Row = + toScala(row(column).asInstanceOf[InternalRow]) + } - case _ => - (item: Any) => extractOption(item) match { - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) - case other => other - } + private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { + override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { + case str: String => UTF8String.fromString(str) + case utf8: UTF8String => utf8 } + override def toScala(catalystValue: UTF8String): String = + if (catalystValue == null) null else catalystValue.toString + override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString } - /** - * Converts Scala objects to catalyst rows / types. - * - * Note: This should be called before do evaluation on Row - * (It does not support UDT) - * This is used to create an RDD or test results with correct types for Catalyst. - */ - def convertToCatalyst(a: Any): Any = a match { - case s: String => UTF8String(s) - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) - case seq: Seq[Any] => seq.map(convertToCatalyst) - case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) - case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray - case m: Map[Any, Any] => - m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap - case other => other + private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { + override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue) + override def toScala(catalystValue: Any): Date = + if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + override def toScalaImpl(row: InternalRow, column: Int): Date = + DateTimeUtils.toJavaDate(row.getInt(column)) } - /** - * Converts Catalyst types used internally in rows to standard Scala types - * This method is slow, and for batch conversion you should be using converter - * produced by createToScalaConverter. - */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => - udt.deserialize(d) + private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { + override def toCatalystImpl(scalaValue: Timestamp): Long = + DateTimeUtils.fromJavaTimestamp(scalaValue) + override def toScala(catalystValue: Any): Timestamp = + if (catalystValue == null) null + else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) + override def toScalaImpl(row: InternalRow, column: Int): Timestamp = + DateTimeUtils.toJavaTimestamp(row.getLong(column)) + } - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToScala(_, arrayType.elementType)) + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: Decimal => d + } + override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal + override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = + row.get(column).asInstanceOf[Decimal].toJavaBigDecimal + } - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { + final override def toScala(catalystValue: Any): Any = catalystValue + final override def toCatalystImpl(scalaValue: T): Any = scalaValue + } - case (r: Row, s: StructType) => - convertRowToScala(r, s) + private object BooleanConverter extends PrimitiveConverter[Boolean] { + override def toScalaImpl(row: InternalRow, column: Int): Boolean = row.getBoolean(column) + } + + private object ByteConverter extends PrimitiveConverter[Byte] { + override def toScalaImpl(row: InternalRow, column: Int): Byte = row.getByte(column) + } + + private object ShortConverter extends PrimitiveConverter[Short] { + override def toScalaImpl(row: InternalRow, column: Int): Short = row.getShort(column) + } - case (d: Decimal, _: DecimalType) => - d.toJavaBigDecimal + private object IntConverter extends PrimitiveConverter[Int] { + override def toScalaImpl(row: InternalRow, column: Int): Int = row.getInt(column) + } - case (i: Int, DateType) => - DateUtils.toJavaDate(i) + private object LongConverter extends PrimitiveConverter[Long] { + override def toScalaImpl(row: InternalRow, column: Int): Long = row.getLong(column) + } - case (s: UTF8String, StringType) => - s.toString() + private object FloatConverter extends PrimitiveConverter[Float] { + override def toScalaImpl(row: InternalRow, column: Int): Float = row.getFloat(column) + } - case (other, _) => - other + private object DoubleConverter extends PrimitiveConverter[Double] { + override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column) } /** - * Creates a converter function that will convert Catalyst types to Scala type. + * Creates a converter function that will convert Scala objects to the specified Catalyst type. * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item: Any) => if (item == null) null else udt.deserialize(item) - - case arrayType: ArrayType => - val elementConverter = createToScalaConverter(arrayType.elementType) - (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) - - case mapType: MapType => - val keyConverter = createToScalaConverter(mapType.keyType) - val valueConverter = createToScalaConverter(mapType.valueType) - (item: Any) => if (item == null) { - null - } else { - item.asInstanceOf[Map[_, _]].map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - } - - case s: StructType => - val converters = s.fields.map(f => createToScalaConverter(f.dataType)) - (item: Any) => { - if (item == null) { - null + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + // Although the `else` branch here is capable of handling inbound conversion of primitives, + // we add some special-case handling for those types here. The motivation for this relates to + // Java method invocation costs: if we have rows that consist entirely of primitive columns, + // then returning the same conversion function for all of the columns means that the call site + // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in + // a measurable performance impact. Note that this optimization will be unnecessary if we + // use code generation to construct Scala Row -> Catalyst Row converters. + def convert(maybeScalaValue: Any): Any = { + if (maybeScalaValue.isInstanceOf[Option[Any]]) { + maybeScalaValue.asInstanceOf[Option[Any]].orNull } else { - convertRowWithConverters(item.asInstanceOf[Row], s, converters) + maybeScalaValue } } - - case _: DecimalType => - (item: Any) => item match { - case d: Decimal => d.toJavaBigDecimal - case other => other - } - - case DateType => - (item: Any) => item match { - case i: Int => DateUtils.toJavaDate(i) - case other => other - } - - case StringType => - (item: Any) => item match { - case s: UTF8String => s.toString() - case other => other - } - - case other => - (item: Any) => item + convert + } else { + getConverterForType(dataType).toCatalyst + } } - def convertRowToScala(r: Row, schema: StructType): Row = { - val ar = new Array[Any](r.size) - var idx = 0 - while (idx < r.size) { - ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) - idx += 1 + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + identity + } else { + getConverterForType(dataType).toScala } - new GenericRowWithSchema(ar, schema) } /** - * Converts a row by applying the provided set of converter functions. It is used for both - * toScala and toCatalyst conversions. + * Converts Scala objects to Catalyst rows / types. + * + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. */ - private[sql] def convertRowWithConverters( - row: Row, - schema: StructType, - converters: Array[Any => Any]): Row = { - val ar = new Array[Any](row.size) - var idx = 0 - while (idx < row.size) { - ar(idx) = converters(idx)(row(idx)) - idx += 1 - } - new GenericRowWithSchema(ar, schema) + def convertToCatalyst(a: Any): Any = a match { + case s: String => StringConverter.toCatalyst(s) + case d: Date => DateConverter.toCatalyst(d) + case t: Timestamp => TimestampConverter.toCatalyst(t) + case d: BigDecimal => BigDecimalConverter.toCatalyst(d) + case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) + case seq: Seq[Any] => seq.map(convertToCatalyst) + case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + case other => other + } + + /** + * Converts Catalyst types used internally in rows to standard Scala types + * This method is slow, and for batch conversion you should be using converter + * produced by createToScalaConverter. + */ + def convertToScala(catalystValue: Any, dataType: DataType): Any = { + createToScalaConverter(dataType)(catalystValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala new file mode 100644 index 000000000000..57de0f26a972 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * An abstract class for row used internal in Spark SQL, which only contain the columns as + * internal types. + */ +abstract class InternalRow extends Row { + + // This is only use for test + override def getString(i: Int): String = { + val str = getAs[UTF8String](i) + if (str != null) str.toString else null + } + + // These expensive API should not be used internally. + final override def getDecimal(i: Int): java.math.BigDecimal = + throw new UnsupportedOperationException + final override def getDate(i: Int): java.sql.Date = + throw new UnsupportedOperationException + final override def getTimestamp(i: Int): java.sql.Timestamp = + throw new UnsupportedOperationException + final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException + final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException + final override def getMap[K, V](i: Int): scala.collection.Map[K, V] = + throw new UnsupportedOperationException + final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] = + throw new UnsupportedOperationException + final override def getStruct(i: Int): Row = throw new UnsupportedOperationException + final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException + final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = + throw new UnsupportedOperationException + + // A default implementation to change the return type + override def copy(): InternalRow = this + override def apply(i: Int): Any = get(i) + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[Row]) { + return false + } + + val other = o.asInstanceOf[Row] + if (length != other.length) { + return false + } + + var i = 0 + while (i < length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = apply(i) + val o2 = other.apply(i) + if (o1.isInstanceOf[Array[Byte]]) { + // handle equality of Array[Byte] + val b1 = o1.asInstanceOf[Array[Byte]] + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + } else if (o1 != o2) { + return false + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + while (i < length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} + +object InternalRow { + /** + * This method can be used to construct a [[Row]] with the given values. + */ + def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray) + + /** Returns an empty row. */ + val empty = apply() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 625c8d3a6212..9a3f9694e4c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -38,12 +38,21 @@ private [sql] object JavaTypeInference { private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + /** + * Infers the corresponding SQL data type of a JavaClean class. + * @param beanClass Java type + * @return (SQL data type, nullable) + */ + def inferDataType(beanClass: Class[_]): (DataType, Boolean) = { + inferDataType(TypeToken.of(beanClass)) + } + /** * Infers the corresponding SQL data type of a Java type. * @param typeToken Java type * @return (SQL data type, nullable) */ - private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6998cc8d9666..21b1de1ab9cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -27,7 +28,11 @@ import org.apache.spark.sql.types._ */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) + // Since we are creating a runtime mirror usign the class loader of current thread, + // we need to use def at here. So, every time we call mirror, it is using the + // class loader of the current thread. + override def mirror: universe.Mirror = + universe.runtimeMirror(Thread.currentThread().getContextClassLoader) } /** @@ -38,7 +43,7 @@ trait ScalaReflection { val universe: scala.reflect.api.Universe /** The mirror used to access types in the universe */ - val mirror: universe.Mirror + def mirror: universe.Mirror import universe._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index fc36b9f1f20d..e8e9b9802e94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -39,7 +40,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { def parseExpression(input: String): Expression = { // Initialize the Keywords. - lexical.initialize(reservedWords) + initLexical phrase(projection)(new lexical.Scanner(input)) match { case Success(plan, _) => plan case failureOrError => sys.error(failureOrError.toString) @@ -48,31 +49,25 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AS = Keyword("AS") protected val ASC = Keyword("ASC") - protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") - protected val COALESCE = Keyword("COALESCE") - protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") - protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") protected val HAVING = Keyword("HAVING") - protected val IF = Keyword("IF") protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") protected val INSERT = Keyword("INSERT") @@ -80,13 +75,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") - protected val LAST = Keyword("LAST") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") - protected val LOWER = Keyword("LOWER") - protected val MAX = Keyword("MAX") - protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") @@ -100,26 +91,14 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val RLIKE = Keyword("RLIKE") protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") - protected val SQRT = Keyword("SQRT") - protected val SUBSTR = Keyword("SUBSTR") - protected val SUBSTRING = Keyword("SUBSTRING") - protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val THEN = Keyword("THEN") protected val TRUE = Keyword("TRUE") protected val UNION = Keyword("UNION") - protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") protected val WITH = Keyword("WITH") - protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - } - protected lazy val start: Parser[LogicalPlan] = start1 | insert | cte @@ -140,12 +119,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { (HAVING ~> expression).? ~ sortType.? ~ (LIMIT ~> expression).? ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g - .map(Aggregate(_, assignAliases(p), withFilter)) - .getOrElse(Project(assignAliases(p), withFilter)) + .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter)) + .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter)) val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) val withOrder = o.map(_(withHaving)).getOrElse(withHaving) @@ -212,7 +191,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val ordering: Parser[Seq[SortOrder]] = ( rep1sep(expression ~ direction.? , ",") ^^ { - case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) + case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) @@ -242,7 +221,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { case e ~ not ~ el ~ eu => val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - not.fold(betweenExpr)(f=> Not(betweenExpr)) + not.fold(betweenExpr)(f => Not(betweenExpr)) } | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } @@ -277,44 +256,50 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { ) protected lazy val function: Parser[Expression] = - ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } - | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } - | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } - | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } - | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^ - { case exps => CountDistinct(exps) } - | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ - { case exp => ApproxCountDistinct(exp) } - | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } - | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } - | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } - | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } - | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } - | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } - | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } - | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } - | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case c ~ t ~ f => If(c, t, f) } - | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ - (ELSE ~> expression).? <~ END ^^ { - case casePart ~ altPart ~ elsePart => - val branches = altPart.flatMap { case whenExpr ~ thenExpr => - Seq(whenExpr, thenExpr) - } ++ elsePart - casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) - } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } - | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ - { case s ~ p ~ l => Substring(s, p, l) } - | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } - | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } - | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } + ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => + if (lexical.normalizeKeyword(udfName) == "count") { + Count(Literal(1)) + } else { + throw new AnalysisException(s"invalid expression $udfName(*)") + } + } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => + lexical.normalizeKeyword(udfName) match { + case "sum" => SumDistinct(exprs.head) + case "count" => CountDistinct(exprs) + case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT") + } + } + | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp) + } else { + throw new AnalysisException(s"invalid function approximate $udfName") + } + } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => + if (lexical.normalizeKeyword(udfName) == "count") { + ApproxCountDistinct(exp, s.toDouble) + } else { + throw new AnalysisException(s"invalid function approximate($floatLit) $udfName") + } + } + | CASE ~> whenThenElse ^^ CaseWhen + | CASE ~> expression ~ whenThenElse ^^ + { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } ) + protected lazy val whenThenElse: Parser[List[Expression]] = + rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { + case altPart ~ elsePart => + altPart.flatMap { case whenExpr ~ thenExpr => + Seq(whenExpr, thenExpr) + } ++ elsePart + } + protected lazy val cast: Parser[Expression] = CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) @@ -365,13 +350,18 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } + | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } | primary ) protected lazy val signedPrimary: Parser[Expression] = sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e} + protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", { + case lexical.Identifier(str) => str + case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str + }) + protected lazy val primary: PackratParser[Expression] = ( literal | expression ~ ("[" ~> expression <~ "]") ^^ @@ -382,9 +372,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot + | attributeName ^^ UnresolvedAttribute.quoted ) protected lazy val dotExpressionHeader: Parser[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0b6e1d44b9c4..15e84e68b988 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.catalyst.analysis -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet +import scala.collection.mutable.ArrayBuffer /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -45,7 +43,7 @@ class Analyzer( registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { + extends RuleExecutor[LogicalPlan] with CheckAnalysis { def resolver: Resolver = { if (conf.caseSensitiveAnalysis) { @@ -74,11 +72,11 @@ class Analyzer( ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: + ResolveAliases :: ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - TrimGroupingAliases :: - typeCoercionRules ++ + HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*) ) @@ -132,35 +130,42 @@ class Analyzer( } /** - * Removes no-op Alias expressions from the plan. + * Replaces [[UnresolvedAlias]]s with concrete aliases. */ - object TrimGroupingAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Aggregate(groups, aggs, child) => - Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) + object ResolveAliases extends Rule[LogicalPlan] { + private def assignAliases(exprs: Seq[NamedExpression]) = { + // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need + // to transform down the whole tree. + exprs.zipWithIndex.map { + case (u @ UnresolvedAlias(child), i) => + child match { + case _: UnresolvedAttribute => u + case ne: NamedExpression => ne + case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)() + case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) + case e if !e.resolved => u + case other => Alias(other, s"_c$i")() + } + case (other, _) => other + } } - } - object ResolveGroupingAnalytics extends Rule[LogicalPlan] { - /** - * Extract attribute set according to the grouping id - * @param bitmask bitmask to represent the selected of the attribute sequence - * @param exprs the attributes in sequence - * @return the attributes of non selected specified via bitmask (with the bit set to 1) - */ - private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) - : OpenHashSet[Expression] = { - val set = new OpenHashSet[Expression](2) - - var bit = exprs.length - 1 - while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) - bit -= 1 - } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case Aggregate(groups, aggs, child) + if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => + Aggregate(groups, assignAliases(aggs), child) - set + case g: GroupingAnalytics + if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => + g.withNewAggs(assignAliases(g.aggregations)) + + case Project(projectList, child) + if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => + Project(assignAliases(projectList), child) } + } + object ResolveGroupingAnalytics extends Rule[LogicalPlan] { /* * GROUP BY a, b, c WITH ROLLUP * is equivalent to @@ -187,44 +192,17 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). - */ - private[this] def expand(g: GroupingSets): Seq[GroupExpression] = { - val result = new scala.collection.mutable.ArrayBuffer[GroupExpression] - - g.bitmasks.foreach { bitmask => - // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs) - - val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown { - case x: Expression if nonSelectedGroupExprSet.contains(x) => - // if the input attribute in the Invalid Grouping Expression set of for this group - // replace it with constant null - Literal.create(null, expr.dataType) - case x if x == g.gid => - // replace the groupingId with concrete value (the bit mask) - Literal.create(bitmask, IntegerType) - }) - - result += GroupExpression(substitution) - } - - result.toSeq - } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a: Cube if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case a: Rollup if a.resolved => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid) - case x: GroupingSets if x.resolved => + case a: Cube => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case a: Rollup => + GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case x: GroupingSets => + val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() Aggregate( - x.groupByExprs :+ x.gid, + x.groupByExprs :+ VirtualColumn.groupingIdAttribute, x.aggregations, - Expand(expand(x), x.child.output :+ x.gid, x.child)) + Expand(x.bitmasks, x.groupByExprs, gid, x.child)) } } @@ -242,7 +220,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => i.copy(table = EliminateSubQueries(getTable(u))) case u: UnresolvedRelation => getTable(u) @@ -250,9 +228,8 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete - * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's - * children. + * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from + * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -263,24 +240,24 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(child = f.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateArray(args), name) if containsStar(args) => + UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateStruct(args), name) if containsStar(args) => + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case o => o :: Nil }, child) @@ -306,7 +283,7 @@ class Analyzer( val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - val (oldRelation, newRelation) = right.collect { + right.collect { // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -331,38 +308,43 @@ class Analyzer( if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. - sys.error( - s""" - |Failure when resolving conflicting references in Join: - |$plan - | - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - """.stripMargin) } - - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + j + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + j.copy(right = newRight) } - j.copy(right = newRight) + + // When resolve `SortOrder`s in Sort based on child, don't report errors as + // we still have chance to resolve it based on grandchild + case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => + val newOrdering = resolveSortOrders(ordering, child, throws = false) + Sort(newOrdering, global, child) case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { - case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && - resolver(nameParts(0), VirtualColumn.groupingIdName) && - q.isInstanceOf[GroupingAnalytics] => - // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics - q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + withPosition(u) { + q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -388,6 +370,31 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } + private def trimUnresolvedAlias(ne: NamedExpression) = ne match { + case UnresolvedAlias(child) => child + case other => other + } + + private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { + ordering.map { order => + // Resolve SortOrder in one round. + // If throws == false or the desired attribute doesn't exist + // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. + // Else, throw exception. + try { + val newOrder = order transformUp { + case u @ UnresolvedAttribute(nameParts) => + plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + newOrder.asInstanceOf[SortOrder] + } catch { + case a: AnalysisException if !throws => order + } + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -398,13 +405,13 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => - val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) + val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) // If this rule was not a no-op, return the transformed plan, otherwise return the original. if (missing.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(p.output, - Sort(resolvedOrdering, global, + Sort(newOrdering, global, Project(projectList ++ missing, child))) } else { logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") @@ -412,19 +419,31 @@ class Analyzer( } case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => - val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) // A small hack to create an object that will allow us to resolve any references that // refer to named expressions that are present in the grouping expressions. val groupingRelation = LocalRelation( grouping.collect { case ne: NamedExpression => ne.toAttribute } ) - val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation) + // Find sort attributes that are projected away so we can temporarily add them back in. + val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation) + + // Find aggregate expressions and evaluate them early, since they can't be evaluated in a + // Sort. + val (withAggsRemoved, aliasedAggregateList) = newOrdering.map { + case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => + val aliased = Alias(aggOrdering.child, "_aggOrdering")() + (aggOrdering.copy(child = aliased.toAttribute), Some(aliased)) + + case other => (other, None) + }.unzip + + val missing = missingAttr ++ aliasedAggregateList.flatten if (missing.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. Project(a.output, - Sort(resolvedOrdering, global, + Sort(withAggsRemoved, global, Aggregate(grouping, aggs ++ missing, child))) } else { s // Nothing we can do here. Return original plan. @@ -432,52 +451,39 @@ class Analyzer( } /** - * Given a child and a grandchild that are present beneath a sort operator, returns - * a resolved sort ordering and a list of attributes that are missing from the child - * but are present in the grandchild. + * Given a child and a grandchild that are present beneath a sort operator, try to resolve + * the sort ordering and returns it with a list of attributes that are missing from the + * child but are present in the grandchild. */ def resolveAndFindMissing( ordering: Seq[SortOrder], child: LogicalPlan, grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { - // Find any attributes that remain unresolved in the sort. - val unresolved: Seq[Seq[String]] = - ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts }) - - // Create a map from name, to resolved attributes, when the desired name can be found - // prior to the projection. - val resolved: Map[Seq[String], NamedExpression] = - unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap - + val newOrdering = resolveSortOrders(ordering, grandchild, throws = true) // Construct a set that contains all of the attributes that we need to evaluate the // ordering. - val requiredAttributes = AttributeSet(resolved.values) - + val requiredAttributes = AttributeSet(newOrdering.filter(_.resolved)) // Figure out which ones are missing from the projection, so that we can add them and // remove them after the sort. val missingInProject = requiredAttributes -- child.output - - // Now that we have all the attributes we need, reconstruct a resolved ordering. - // It is important to do it here, instead of waiting for the standard resolved as adding - // attributes to the project below can actually introduce ambiquity that was not present - // before. - val resolvedOrdering = ordering.map(_ transform { - case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u) - }).asInstanceOf[Seq[SortOrder]] - - (resolvedOrdering, missingInProject.toSeq) + // It is important to return the new SortOrders here, instead of waiting for the standard + // resolving process as adding attributes to the project below can actually introduce + // ambiguity that was not present before. + (newOrdering, missingInProject.toSeq) } } /** - * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. + * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { - case u @ UnresolvedFunction(name, children) if u.childrenResolved => - registry.lookupFunction(name, children) + case u @ UnresolvedFunction(name, children) => + withPosition(u) { + registry.lookupFunction(name, children) + } } } } @@ -508,20 +514,21 @@ class Analyzer( object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if aggregate.resolved && containsAggregate(havingCondition) => { - val evaluatedCondition = Alias(havingCondition, "havingCondition")() + if aggregate.resolved && containsAggregate(havingCondition) => + + val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs Project(aggregate.output, Filter(evaluatedCondition.toAttribute, aggregate.copy(aggregateExpressions = aggExprsWithHaving))) - } } - protected def containsAggregate(condition: Expression): Boolean = + protected def containsAggregate(condition: Expression): Boolean = { condition .collect { case ae: AggregateExpression => ae } .nonEmpty + } } /** @@ -530,16 +537,15 @@ class Analyzer( * - concrete attribute references for their output. * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). * - * Names for the output [[Attributes]] are extracted from [[Alias]] or [[MultiAlias]] expressions + * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an * [[AnalysisException]] is throw. */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Generate if !p.child.resolved || !p.generator.resolved => p - case g: Generate if g.resolved == false => - g.copy( - generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + case g: Generate if !g.resolved => + g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) case p @ Project(projectList, child) => // Holds the resolved generator, if one exists in the project list. @@ -576,8 +582,13 @@ class Analyzer( /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ private object AliasedGenerator { def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) => Some(g, names) + case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 => + // If not given the default names, and the TGF with multiple output columns + failAnalysis( + s"""Expect multiple names given for ${g.getClass.getName}, + |but only single name '${name}' specified""".stripMargin) + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) case _ => None } } @@ -634,10 +645,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = + private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = projectList.exists(hasWindowFunction) - def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: NamedExpression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -645,14 +656,24 @@ class Analyzer( } /** - * From a Seq of [[NamedExpression]]s, extract window expressions and - * other regular expressions. + * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and + * other regular expressions that do not contain any window expression. For example, for + * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract + * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in + * the window expression as attribute references. So, the first returned value will be + * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be + * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. + * + * @return (seq of expressions containing at lease one window expressions, + * seq of non-window expressions) */ - def extract( + private def extract( expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = { - // First, we simple partition the input expressions to two part, one having - // WindowExpressions and another one without WindowExpressions. - val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction) + // First, we partition the input expressions to two part. For the first part, + // every expression in it contain at least one WindowExpression. + // Expressions in the second part do not have any WindowExpression. + val (expressionsWithWindowFunctions, regularExpressions) = + expressions.partition(hasWindowFunction) // Then, we need to extract those regular expressions used in the WindowExpression. // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5), @@ -661,8 +682,8 @@ class Analyzer( val extractedExprBuffer = new ArrayBuffer[NamedExpression]() def extractExpr(expr: Expression): Expression = expr match { case ne: NamedExpression => - // If a named expression is not in regularExpressions, add extract it and replace it - // with an AttributeReference. + // If a named expression is not in regularExpressions, add it to + // extractedExprBuffer and replace it with an AttributeReference. val missingExpr = AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer) if (missingExpr.nonEmpty) { @@ -679,8 +700,9 @@ class Analyzer( withName.toAttribute } - // Now, we extract expressions from windowExpressions by using extractExpr. - val newWindowExpressions = windowExpressions.map { + // Now, we extract regular expressions from expressionsWithWindowFunctions + // by using extractExpr. + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). @@ -706,37 +728,80 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - (newWindowExpressions, regularExpressions ++ extractedExprBuffer) - } + (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer) + } // end of extract /** * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. */ - def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = { - // First, we group window expressions based on their Window Spec. - val groupedWindowExpression = windowExpressions.groupBy { expr => - val windowSpec = expr.collectFirst { + private def addWindow( + expressionsWithWindowFunctions: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions + // and put those extracted WindowExpressions to extractedWindowExprBuffer. + // This step is needed because it is possible that an expression contains multiple + // WindowExpressions with different Window Specs. + // After extracting WindowExpressions, we need to construct a project list to generate + // expressionsWithWindowFunctions based on extractedWindowExprBuffer. + // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract + // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to + // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)". + // Then, the projectList will be [_we0/_we1]. + val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]() + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { + // We need to use transformDown because we want to trigger + // "case alias @ Alias(window: WindowExpression, _)" first. + _.transformDown { + case alias @ Alias(window: WindowExpression, _) => + // If a WindowExpression has an assigned alias, just use it. + extractedWindowExprBuffer += alias + alias.toAttribute + case window: WindowExpression => + // If there is no alias assigned to the WindowExpressions. We create an + // internal column. + val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")() + extractedWindowExprBuffer += withName + withName.toAttribute + }.asInstanceOf[NamedExpression] + } + + // Second, we group extractedWindowExprBuffer based on their Window Spec. + val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => + val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec + }.distinct + + // We do a final check and see if we only have a single Window Spec defined in an + // expressions. + if (distinctWindowSpec.length == 0 ) { + failAnalysis(s"$expr does not have any WindowExpression.") + } else if (distinctWindowSpec.length > 1) { + // newExpressionsWithWindowFunctions only have expressions with a single + // WindowExpression. If we reach here, we have a bug. + failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + + s"Please file a bug report with this error message, stack trace, and the query.") + } else { + distinctWindowSpec.head } - windowSpec.getOrElse( - failAnalysis(s"$windowExpressions does not have any WindowExpression.")) }.toSeq - // For every Window Spec, we add a Window operator and set currentChild as the child of it. + // Third, for every Window Spec, we add a Window operator and set currentChild as the + // child of it. var currentChild = child var i = 0 - while (i < groupedWindowExpression.size) { - val (windowSpec, windowExpressions) = groupedWindowExpression(i) + while (i < groupedWindowExpressions.size) { + val (windowSpec, windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) - // Move to next WindowExpression. + // Move to next Window Spec. i += 1 } - // We return the top operator. - currentChild - } + // Finally, we create a Project to output currentChild's output + // newExpressionsWithWindowFunctions. + Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + } // end of addWindow // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. @@ -793,9 +858,8 @@ class Analyzer( } /** - * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are - * only required to provide scoping information for attributes and can be removed once analysis is - * complete. + * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide + * scoping information for attributes and can be removed once analysis is complete. */ object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 208021c42132..1541491608b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.EmptyConf @@ -81,18 +85,18 @@ trait Catalog { } class SimpleCatalog(val conf: CatalystConf) extends Catalog { - val tables = new mutable.HashMap[String, LogicalPlan]() + val tables = new ConcurrentHashMap[String, LogicalPlan] override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables += ((getDbTableName(tableIdent), plan)) + tables.put(getDbTableName(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) - tables -= getDbTableName(tableIdent) + tables.remove(getDbTableName(tableIdent)) } override def unregisterAllTables(): Unit = { @@ -101,10 +105,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - tables.get(getDbTableName(tableIdent)) match { - case Some(_) => true - case None => false - } + tables.containsKey(getDbTableName(tableIdent)) } override def lookupRelation( @@ -112,7 +113,10 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val tableFullName = getDbTableName(tableIdent) - val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName")) + val table = tables.get(tableFullName) + if (table == null) { + sys.error(s"Table Not Found: $tableFullName") + } val tableWithQualifiers = Subquery(tableIdent.last, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are @@ -121,9 +125,11 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - tables.map { - case (name, _) => (name, true) - }.toSeq + val result = ArrayBuffer.empty[(String, Boolean)] + for (name <- tables.keySet()) { + result += ((name, true)) + } + result } override def refreshTable(databaseName: String, tableName: String): Unit = { @@ -140,7 +146,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... - val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + val overrides = new mutable.HashMap[(Option[String], String), LogicalPlan]() abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f104e742c90f..476ac2b7cb47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.types._ * Throws user facing errors when passed invalid queries that fail to analyze. */ trait CheckAnalysis { - self: Analyzer => /** * Override to provide additional checks for correct analysis. @@ -41,35 +40,35 @@ trait CheckAnalysis { def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { exprs.flatMap(_.collect { case e: Generator => true - }).length >= 1 + }).nonEmpty } def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => - if (operator.childrenResolved) { - a match { - case UnresolvedAttribute(nameParts) => - // Throw errors for specific problems with get field. - operator.resolveChildren(nameParts, resolver, throwErrors = true) - } - } - val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } + case c: Cast if !c.resolved => failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case b: BinaryExpression if !b.resolved => + case WindowExpression(UnresolvedWindowFunction(name, _), _) => failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}") + s"Could not resolve window function '$name'. " + + "Note that, using window functions currently requires a HiveContext") case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => // The window spec is not valid. @@ -86,24 +85,17 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK - case e: Attribute if !groupingExprs.contains(e) => + case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.contains(e) => // OK + case e if groupingExprs.exists(_.semanticEquals(e)) => // OK case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } - val cleaned = aggregateExprs.map(_.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g, _) => g - }) - - cleaned.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidAggregateExpression) case _ => // Fallbacks to the following checks } @@ -129,6 +121,17 @@ trait CheckAnalysis { case _ => // Analysis successful! } + + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + } extendedCheckRules.foreach(_(plan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 16ca5bcd57a7..fef276353022 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,42 +17,50 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.Expression -import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.StringKeyHashMap + /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { - type FunctionBuilder = Seq[Expression] => Expression def registerFunction(name: String, builder: FunctionBuilder): Unit + @throws[AnalysisException]("If function does not exist") def lookupFunction(name: String, children: Seq[Expression]): Expression - - def caseSensitive: Boolean } -trait OverrideFunctionRegistry extends FunctionRegistry { +class OverrideFunctionRegistry(underlying: FunctionRegistry) extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } - abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children)) + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders.get(name).map(_(children)).getOrElse(underlying.lookupFunction(name, children)) } } -class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) +class SimpleFunctionRegistry extends FunctionRegistry { + + private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders(name)(children) + val func = functionBuilders.get(name).getOrElse { + throw new AnalysisException(s"undefined function $name") + } + func(children) } } @@ -68,30 +76,134 @@ object EmptyFunctionRegistry extends FunctionRegistry { override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - - override def caseSensitive: Boolean = throw new UnsupportedOperationException } -/** - * Build a map with String type of key, and it also supports either key case - * sensitive or insensitive. - * TODO move this into util folder? - */ -object StringKeyHashMap { - def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { - case false => new StringKeyHashMap[T](_.toLowerCase) - case true => new StringKeyHashMap[T](identity) - } -} -class StringKeyHashMap[T](normalizer: (String) => String) { - private val base = new collection.mutable.HashMap[String, T]() +object FunctionRegistry { + + type FunctionBuilder = Seq[Expression] => Expression - def apply(key: String): T = base(normalizer(key)) + val expressions: Map[String, FunctionBuilder] = Map( + // misc non-aggregate functions + expression[Abs]("abs"), + expression[CreateArray]("array"), + expression[Coalesce]("coalesce"), + expression[Explode]("explode"), + expression[If]("if"), + expression[IsNull]("isnull"), + expression[IsNotNull]("isnotnull"), + expression[Coalesce]("nvl"), + expression[Rand]("rand"), + expression[Randn]("randn"), + expression[CreateStruct]("struct"), + expression[CreateNamedStruct]("named_struct"), + expression[Sqrt]("sqrt"), + + // math functions + expression[Acos]("acos"), + expression[Asin]("asin"), + expression[Atan]("atan"), + expression[Atan2]("atan2"), + expression[Bin]("bin"), + expression[Cbrt]("cbrt"), + expression[Ceil]("ceil"), + expression[Ceil]("ceiling"), + expression[Cos]("cos"), + expression[EulerNumber]("e"), + expression[Exp]("exp"), + expression[Expm1]("expm1"), + expression[Floor]("floor"), + expression[Factorial]("factorial"), + expression[Hypot]("hypot"), + expression[Hex]("hex"), + expression[Logarithm]("log"), + expression[Log]("ln"), + expression[Log10]("log10"), + expression[Log1p]("log1p"), + expression[UnaryMinus]("negative"), + expression[Pi]("pi"), + expression[Log2]("log2"), + expression[Pow]("pow"), + expression[Pow]("power"), + expression[UnaryPositive]("positive"), + expression[Rint]("rint"), + expression[ShiftLeft]("shiftleft"), + expression[ShiftRight]("shiftright"), + expression[ShiftRightUnsigned]("shiftrightunsigned"), + expression[Signum]("sign"), + expression[Signum]("signum"), + expression[Sin]("sin"), + expression[Sinh]("sinh"), + expression[Tan]("tan"), + expression[Tanh]("tanh"), + expression[ToDegrees]("degrees"), + expression[ToRadians]("radians"), + + // misc functions + expression[Md5]("md5"), + expression[Sha2]("sha2"), + expression[Sha1]("sha1"), + expression[Sha1]("sha"), + expression[Crc32]("crc32"), + + // aggregate functions + expression[Average]("avg"), + expression[Count]("count"), + expression[First]("first"), + expression[Last]("last"), + expression[Max]("max"), + expression[Min]("min"), + expression[Sum]("sum"), + + // string functions + expression[Ascii]("ascii"), + expression[Base64]("base64"), + expression[Encode]("encode"), + expression[Decode]("decode"), + expression[Lower]("lcase"), + expression[Lower]("lower"), + expression[StringLength]("length"), + expression[Levenshtein]("levenshtein"), + expression[Substring]("substr"), + expression[Substring]("substring"), + expression[UnBase64]("unbase64"), + expression[Upper]("ucase"), + expression[Unhex]("unhex"), + expression[Upper]("upper"), + + // datetime functions + expression[CurrentDate]("current_date"), + expression[CurrentTimestamp]("current_timestamp") + ) + + val builtin: FunctionRegistry = { + val fr = new SimpleFunctionRegistry + expressions.foreach { case (name, builder) => fr.registerFunction(name, builder) } + fr + } - def get(key: String): Option[T] = base.get(normalizer(key)) - def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) - def remove(key: String): Option[T] = base.remove(normalizer(key)) - def iterator: Iterator[(String, T)] = base.toIterator + /** See usage above. */ + private def expression[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { + + // See if we can find a constructor that accepts Seq[Expression] + val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption + val builder = (expressions: Seq[Expression]) => { + if (varargCtor.isDefined) { + // If there is an apply method that accepts Seq[Expression], use that one. + varargCtor.get.newInstance(expressions).asInstanceOf[Expression] + } else { + // Otherwise, find an ctor method that matches the number of arguments, and use that. + val params = Seq.fill(expressions.size)(classOf[Expression]) + val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { + case Success(e) => + e + case Failure(e) => + throw new AnalysisException(s"Invalid number of arguments for function $name") + } + f.newInstance(expressions : _*).asInstanceOf[Expression] + } + } + (name, builder) + } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index fe0d3f29977c..5367b7f3308e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -17,12 +17,38 @@ package org.apache.spark.sql.catalyst.analysis +import javax.annotation.Nullable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ + +/** + * A collection of [[Rule Rules]] that can be used to coerce differing types that + * participate in operations into compatible ones. Most of these rules are based on Hive semantics, + * but they do not introduce any dependencies on the hive codebase. For this reason they remain in + * Catalyst until we have a more standard set of coercions. + */ object HiveTypeCoercion { + + val typeCoercionRules = + PropagateTypes :: + InConversion :: + WidenTypes :: + PromoteStrings :: + DecimalPrecision :: + BooleanEquality :: + StringToIntegralCasts :: + FunctionArgumentConversion :: + CaseWhenCoercion :: + IfCoercion :: + Division :: + PropagateTypes :: + ImplicitTypeCasts :: + Nil + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: private val numericPrecedence = @@ -41,7 +67,7 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -57,34 +83,38 @@ object HiveTypeCoercion { case _ => None } -} -/** - * A collection of [[Rule Rules]] that can be used to coerce differing types that - * participate in operations into compatible ones. Most of these rules are based on Hive semantics, - * but they do not introduce any dependencies on the hive codebase. For this reason they remain in - * Catalyst until we have a more standard set of coercions. - */ -trait HiveTypeCoercion { + /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */ + private def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { + findTightestCommonTypeOfTwo(left, right).orElse((left, right) match { + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case _ => None + }) + } - import HiveTypeCoercion._ + /** + * Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use + * [[findTightestCommonTypeToString]] to find the TightestCommonType. + */ + private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => + findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c)) + }) + } - val typeCoercionRules = - PropagateTypes :: - ConvertNaNs :: - InConversion :: - WidenTypes :: - PromoteStrings :: - DecimalPrecision :: - BooleanComparisons :: - BooleanCasts :: - StringToIntegralCasts :: - FunctionArgumentConversion :: - CaseWhenCoercion :: - Division :: - PropagateTypes :: - ExpectedInputConversion :: - Nil + /** + * Find the tightest common type of a set of types by continuously applying + * `findTightestCommonTypeOfTwo` on these types. + */ + private def findTightestCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonTypeOfTwo(d, c) + }) + } /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to @@ -98,51 +128,22 @@ trait HiveTypeCoercion { // Don't propagate types from unresolved children. case q: LogicalPlan if !q.childrenResolved => q - case q: LogicalPlan => q transformExpressions { - case a: AttributeReference => - q.inputSet.find(_.exprId == a.exprId) match { - // This can happen when a Attribute reference is born in a non-leaf node, for example - // due to a call to an external script like in the Transform operator. - // TODO: Perhaps those should actually be aliases? - case None => a - // Leave the same if the dataTypes match. - case Some(newType) if a.dataType == newType.dataType => a - case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") - newType - } - } - } - } - - /** - * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to - * the appropriate numeric equivalent. - */ - object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN") - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - - /* Double Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => - b.makeCopy(Array(b.right, Literal(Double.NaN))) - case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Double.NaN), b.left)) - - /* Float Conversions */ - case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => - b.makeCopy(Array(b.right, Literal(Float.NaN))) - case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) - case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => - b.makeCopy(Array(Literal(Float.NaN), b.left)) - } + case q: LogicalPlan => + val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap + q transformExpressions { + case a: AttributeReference => + inputMap.get(a.exprId) match { + // This can happen when a Attribute reference is born in a non-leaf node, for example + // due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave the same if the dataTypes match. + case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + newType + } + } } } @@ -167,28 +168,30 @@ trait HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. - case (l, r) if l.dataType == StringType && r.dataType != StringType => - (l, Alias(Cast(r, StringType), r.name)()) - case (l, r) if l.dataType != StringType && r.dataType == StringType => - (Alias(Cast(l, StringType), l.name)(), r) - - case (l, r) if l.dataType != r.dataType => - logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonType(l.dataType, r.dataType).map { widestType => + case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => + (lhs, Alias(Cast(rhs, StringType), rhs.name)()) + case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => + (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + + case (lhs, rhs) if lhs.dataType != rhs.dataType => + logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") + findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => val newLeft = - if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() + if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() val newRight = - if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)() + if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() (newLeft, newRight) - }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged. + }.getOrElse { + // If there is no applicable conversion, leave expression unchanged. + (lhs, rhs) + } + case other => other } @@ -212,17 +215,15 @@ trait HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryExpressions. + // Also widen types for BinaryOperator. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => - val newLeft = - if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) - val newRight = - if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) b.makeCopy(Array(newLeft, newRight)) }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. } @@ -237,57 +238,53 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a: BinaryArithmetic if a.left.dataType == StringType => - a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) - case a: BinaryArithmetic if a.right.dataType == StringType => - a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) - - // we should cast all timestamp/date/string compare into string compare - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType == DateType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) - case p: BinaryComparison if p.left.dataType == DateType && - p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) - case p: BinaryComparison if p.left.dataType == TimestampType && - p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) - case p: BinaryComparison if p.left.dataType == TimestampType && - p.right.dataType == DateType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - case p: BinaryComparison if p.left.dataType == DateType && - p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) - - case p: BinaryComparison if p.left.dataType == StringType && - p.right.dataType != StringType => - p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) - case p: BinaryComparison if p.left.dataType != StringType && - p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - - case i @ In(a, b) if a.dataType == DateType && - b.forall(_.dataType == StringType) => + case a @ BinaryArithmetic(left @ StringType(), r) => + a.makeCopy(Array(Cast(left, DoubleType), r)) + case a @ BinaryArithmetic(left, right @ StringType()) => + a.makeCopy(Array(left, Cast(right, DoubleType))) + + // For equality between string and timestamp we cast the string to a timestamp + // so that things like rounding of subsecond precision does not affect the comparison. + case p @ Equality(left @ StringType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, TimestampType), right)) + case p @ Equality(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(left, Cast(right, TimestampType))) + + // We should cast all relative timestamp/date/string comparison into string comparisions + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true + case p @ BinaryComparison(left @ StringType(), right @ DateType()) => + p.makeCopy(Array(left, Cast(right, StringType))) + case p @ BinaryComparison(left @ DateType(), right @ StringType()) => + p.makeCopy(Array(Cast(left, StringType), right)) + case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => + p.makeCopy(Array(left, Cast(right, StringType))) + case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(Cast(left, StringType), right)) + + // Comparisons between dates and timestamps. + case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => + p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + + case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType => + p.makeCopy(Array(Cast(left, DoubleType), right)) + case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType => + p.makeCopy(Array(left, Cast(right, DoubleType))) + + case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(Cast(a, StringType), b)) - case i @ In(a, b) if a.dataType == TimestampType && - b.forall(_.dataType == StringType) => + case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) - case i @ In(a, b) if a.dataType == DateType && - b.forall(_.dataType == TimestampType) => + case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case i @ In(a, b) if a.dataType == TimestampType && - b.forall(_.dataType == DateType) => + case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case Sum(e) if e.dataType == StringType => - Sum(Cast(e, DoubleType)) - case Average(e) if e.dataType == StringType => - Average(Cast(e, DoubleType)) - case Sqrt(e) if e.dataType == StringType => - Sqrt(Cast(e, DoubleType)) + case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) + case Average(e @ StringType()) => Average(Cast(e, DoubleType)) } } @@ -296,6 +293,9 @@ trait HiveTypeCoercion { */ object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) } @@ -347,17 +347,17 @@ trait HiveTypeCoercion { import scala.math.{max, min} // Conversion rules for integer types into fixed-precision decimals - val intTypeToFixed: Map[DataType, DecimalType] = Map( + private val intTypeToFixed: Map[DataType, DecimalType] = Map( ByteType -> DecimalType(3, 0), ShortType -> DecimalType(5, 0), IntegerType -> DecimalType(10, 0), LongType -> DecimalType(20, 0) ) - def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType // Conversion rules for float and double into fixed-precision decimals - val floatTypeToFixed: Map[DataType, DecimalType] = Map( + private val floatTypeToFixed: Map[DataType, DecimalType] = Map( FloatType -> DecimalType(7, 7), DoubleType -> DecimalType(15, 15) ) @@ -366,22 +366,22 @@ trait HiveTypeCoercion { // fix decimal precision for union case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { - case (l, r) if l.dataType != r.dataType => - (l.dataType, r.dataType) match { + case (lhs, rhs) if lhs.dataType != rhs.dataType => + (lhs.dataType, rhs.dataType) match { case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => // Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)()) + (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(l, intTypeToFixed(t)), l.name)(), r) + (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (l, Alias(Cast(r, intTypeToFixed(t)), r.name)()) + (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r) + (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (l, Alias(Cast(r, floatTypeToFixed(t)), r.name)()) - case _ => (l, r) + (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) + case _ => (lhs, rhs) } case other => other } @@ -439,34 +439,25 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + // When we compare 2 decimal types with different precisions, cast them to the smallest + // common precision. + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = DecimalType(max(p1, p2), max(s1, s2)) + b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b: BinaryExpression if b.left.dataType != b.right.dataType => - (b.left.dataType, b.right.dataType) match { + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + b.makeCopy(Array(left, Cast(right, intTypeToFixed(t)))) case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + b.makeCopy(Array(left, Cast(right, DoubleType))) case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + b.makeCopy(Array(Cast(left, DoubleType), right)) case _ => b } @@ -480,56 +471,66 @@ trait HiveTypeCoercion { } /** - * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. + * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ - object BooleanComparisons extends Rule[LogicalPlan] { - val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_)) - val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_)) + object BooleanEquality extends Rule[LogicalPlan] { + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1)) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0)) + + private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { + CaseKeyWhen(numericExpr, Seq( + Literal(trueValues.head), booleanExpr, + Literal(falseValues.head), Not(booleanExpr), + Literal(false))) + } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + private def transform(booleanExpr: Expression, numericExpr: Expression) = { + If(Or(IsNull(booleanExpr), IsNull(numericExpr)), + Literal.create(null, BooleanType), + buildCaseKeyWhen(booleanExpr, numericExpr)) + } - // Hive treats (true = 1) as true and (false = 0) as true. - case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l - case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r - case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) - case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) - - // No need to change other EqualTo operators as that actually makes sense for boolean types. - case e: EqualTo => e - // No need to change the EqualNullSafe operators, too - case e: EqualNullSafe => e - // Otherwise turn them to Byte types so that there exists and ordering. - case p: BinaryComparison if p.left.dataType == BooleanType && - p.right.dataType == BooleanType => - p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { + CaseWhen(Seq( + And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), + Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), + buildCaseKeyWhen(booleanExpr, numericExpr) + )) } - } - /** - * Casts to/from [[BooleanType]] are transformed into comparisons since - * the JVM does not consider Booleans to be numeric types. - */ - object BooleanCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // Skip if the type is boolean type already. Note that this extra cast should be removed - // by optimizer.SimplifyCasts. - case Cast(e, BooleanType) if e.dataType == BooleanType => e - // DateType should be null if be cast to boolean. - case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) - // If the data type is not boolean and is being cast boolean, turn it into a comparison - // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. - case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) - // Stringify boolean if casting to StringType. - // TODO Ensure true/false string letter casing is consistent with Hive in all cases. - case Cast(e, StringType) if e.dataType == BooleanType => - If(e, Literal("true"), Literal("false")) - // Turn true into 1, and false into 0 if casting boolean into other types. - case Cast(e, dataType) if e.dataType == BooleanType => - Cast(If(e, Literal(1), Literal(0)), dataType) + + // Hive treats (true = 1) as true and (false = 0) as true, + // all other cases are considered as false. + + // We may simplify the expression if one side is literal numeric values + case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => bool + case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(bool) + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => bool + case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => Not(bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(bool), bool) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) + + case EqualTo(left @ BooleanType(), right @ NumericType()) => + transform(left , right) + case EqualTo(left @ NumericType(), right @ BooleanType()) => + transform(right, left) + case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => + transformNullSafe(left, right) + case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => + transformNullSafe(right, left) } } @@ -556,12 +557,12 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !a.resolved => - val commonType = a.childTypes.reduce( - (a,b) => - findTightestCommonType(a,b).getOrElse(StringType)) - CreateArray( - children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) + case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case None => a + } // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. @@ -587,17 +588,11 @@ trait HiveTypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => - val dt: Option[DataType] = Some(NullType) + case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - val rt = types.foldLeft(dt)((r, c) => r match { - case None => None - case Some(d) => findTightestCommonType(d, c) - }) - rt match { - case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) - case None => - sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) + case None => c } } } @@ -608,19 +603,15 @@ trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + // Skip nodes who has not been resolved yet, + // as this is an extra rule which should be applied at last. + case e if !e.resolved => e // Decimal and Double remain the same - case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - - case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => - Divide(l, Cast(r, DecimalType.Unlimited)) - case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => - Divide(Cast(l, DecimalType.Unlimited), r) + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) + case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) } } @@ -628,47 +619,129 @@ trait HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - import HiveTypeCoercion._ + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => + logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, value) if value.dataType != commonType => + Seq(when, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) + case other => other + }.reduce(_ ++ _) + c match { + case _: CaseWhen => CaseWhen(castedBranches) + case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) + } + }.getOrElse(c) + + case c: CaseKeyWhen if c.childrenResolved && !c.resolved => + val maybeCommonType = + findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => + Seq(Cast(whenExpr, commonType), thenExpr) + case other => other + }.reduce(_ ++ _) + CaseKeyWhen(Cast(c.key, commonType), castedBranches) + }.getOrElse(c) + } + } + /** + * Coerces the type of different branches of If statement to a common type. + */ + object IfCoercion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => - logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") - val commonType = cw.valueTypes.reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = cw.branches.sliding(2, 2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case s => s - }.reduce(_ ++ _) - cw match { - case _: CaseWhen => - CaseWhen(transformedBranches) - case CaseKeyWhen(key, _) => - CaseKeyWhen(key, transformedBranches) - } + // Find tightest common type for If, if the true value and false value have different types. + case i @ If(pred, left, right) if left.dataType != right.dataType => + findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) + }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. + + // Convert If(null literal, _, _) into boolean type. + // In the optimizer, we should short-circuit this directly into false value. + case If(pred, left, right) if pred.dataType == NullType => + If(Literal.create(null, BooleanType), left, right) } } /** * Casts types according to the expected input types for Expressions that have the trait - * `ExpectsInputTypes`. + * [[ExpectsInputTypes]]. */ - object ExpectedInputConversion extends Rule[LogicalPlan] { - + object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => - val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { - case (child, actual, expected) => - if (actual == expected) child else Cast(child, expected) + case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) => + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + // If we cannot do the implicit cast, just use the original input. + implicitCast(in, expected).getOrElse(in) } - e.withNewChildren(newC) + e.withNewChildren(children) + } + + /** + * Given an expected data type, try to cast the expression and return the cast expression. + * + * If the expression already fits the input type, we simply return the expression itself. + * If the expression has an incompatible type that cannot be implicitly cast, return None. + */ + def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { + val inType = e.dataType + + // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. + // We wrap immediately an Option after this. + @Nullable val ret: Expression = (inType, expectedType) match { + + // If the expected type is already a parent of the input type, no need to cast. + case _ if expectedType.isParentOf(inType) => e + + // Cast null type (usually from null literals) into target types + case (NullType, target) => Cast(e, target.defaultConcreteType) + + // Implicit cast among numeric types + // If input is a numeric type but not decimal, and we expect a decimal type, + // cast the input to unlimited precision decimal. + case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => + Cast(e, DecimalType.Unlimited) + // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long + case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) + case (_: NumericType, target: NumericType) => e + + // Implicit cast between date time types + case (DateType, TimestampType) => Cast(e, TimestampType) + case (TimestampType, DateType) => Cast(e, DateType) + + // Implicit cast from/to string + case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited) + case (StringType, target: NumericType) => Cast(e, target) + case (StringType, DateType) => Cast(e, DateType) + case (StringType, TimestampType) => Cast(e, TimestampType) + case (StringType, BinaryType) => Cast(e, BinaryType) + case (any, StringType) if any != StringType => Cast(e, StringType) + + // Type collection. + // First see if we can find our input type in the type collection. If we can, then just + // use the current expression; otherwise, find the first one we can implicitly cast. + case (_, TypeCollection(types)) => + if (types.exists(_.isParentOf(inType))) { + e + } else { + types.flatMap(implicitCast(e, _)).headOption.orNull + } + + // Else, just return the same input expression + case _ => null + } + Option(ret) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala new file mode 100644 index 000000000000..79c3528a522d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +/** + * Represents the result of `Expression.checkInputDataTypes`. + * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true. + */ +trait TypeCheckResult { + def isFailure: Boolean = !isSuccess + def isSuccess: Boolean +} + +object TypeCheckResult { + + /** + * Represents the successful result of `Expression.checkInputDataTypes`. + */ + object TypeCheckSuccess extends TypeCheckResult { + def isSuccess: Boolean = true + } + + /** + * Represents the failing result of `Expression.checkInputDataTypes`, + * with a error message to show the reason of failure. + */ + case class TypeCheckFailure(message: String) extends TypeCheckResult { + def isSuccess: Boolean = false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 2999c2ef3efe..ae3adbab0510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -67,7 +67,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name" @@ -85,7 +85,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -107,7 +107,7 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] { override lazy val resolved = false // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] @@ -166,7 +166,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child AS $names" @@ -200,8 +200,27 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child[$extraction]" } + +/** + * Holds the expression that has yet to be aliased. + */ +case class UnresolvedAlias(child: Expression) extends NamedExpression + with trees.UnaryNode[Expression] { + + override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def name: String = throw new UnresolvedException(this, "name") + + override lazy val resolved = false + + override def eval(input: InternalRow = null): Any = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 4c0d70203c6f..51821757967d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ @@ -61,7 +60,7 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- : Expression= UnaryMinus(expr) + def unary_- : Expression = UnaryMinus(expr) def unary_! : Predicate = Not(expr) def unary_~ : Expression = BitwiseNot(expr) @@ -141,7 +140,7 @@ package object dsl { // Note that if we make ExpressionConversions an object rather than a trait, we can // then make this a value class to avoid the small penalty of runtime instantiation. def $(args: Any*): analysis.UnresolvedAttribute = { - analysis.UnresolvedAttribute(sc.s(args :_*)) + analysis.UnresolvedAttribute(sc.s(args : _*)) } } @@ -234,129 +233,59 @@ package object dsl { implicit class DslAttribute(a: AttributeReference) { def notNull: AttributeReference = a.withNullability(false) def nullable: AttributeReference = a.withNullability(true) - - // Protobuf terminology - def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } - object expressions extends ExpressionConversions // scalastyle:ignore - abstract class LogicalPlanFunctions { - def logicalPlan: LogicalPlan - - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + object plans { // scalastyle:ignore + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) - def join( + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, condition: Option[Expression] = None): LogicalPlan = - Join(logicalPlan, otherPlan, joinType, condition) + Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { + val aliasedExprs = aggregateExprs.map { + case ne: NamedExpression => ne + case e => Alias(e, e.toString)() + } + Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - Aggregate(groupingExprs, aliasedExprs, logicalPlan) - } - def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) + + def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) - def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) + def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names - def generate( + // TODO specify the output column names + def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = - InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = + InsertIntoTable( + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) - } - - object plans { // scalastyle:ignore - implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } } - - case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*): ScalaUdf = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) - } - } - - // scalastyle:off - /** functionToUdfBuilder 1-22 were generated by this script - - (1 to 22).map { x => - val argTypes = Seq.fill(x)("_").mkString(", ") - s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func)" - } - */ - - implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - // scalastyle:on } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0fd4f9b374ee..d2a90a50c89f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -49,11 +49,4 @@ package object errors { case e: Exception => throw new TreeNodeException(tree, msg, e) } } - - /** - * Executes `f` which is expected to throw a - * [[catalyst.errors.TreeNodeException TreeNodeException]]. The first tree encountered in - * the stack of exceptions of type `TreeType` is returned. - */ - def getTree[TreeType <: TreeNode[_]](f: => Unit): TreeType = ??? // TODO: Implement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index c6217f07c452..dc0b4ac5cd9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types._ /** * A bound reference points to a specific slot in the input tuple, allowing the actual value @@ -30,11 +31,9 @@ import org.apache.spark.sql.catalyst.trees case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends NamedExpression with trees.LeafNode[Expression] { - type EvaluatedType = Any - override def toString: String = s"input[$ordinal]" - override def eval(input: Row): Any = input(ordinal) + override def eval(input: InternalRow): Any = input(ordinal) override def name: String = s"i[$ordinal]" @@ -43,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = i.isNullAt($ordinal); + ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? + ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); + """ + } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d8cf2b2e3243..2d99d1a3fe8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -17,105 +17,119 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} -import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String -/** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) - - override def foldable: Boolean = child.foldable - - override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable - - private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { - case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true - case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true - case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null - case _ => false - } +object Cast { - private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to + /** + * Returns true iff we can cast `from` type to `to` type. + */ + def canCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (fromType, toType) if fromType == toType => true - private[this] def resolve(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (from, to) if from == to => true + case (NullType, _) => true - case (NullType, _) => true + case (_, StringType) => true - case (_, StringType) => true + case (StringType, BinaryType) => true - case (StringType, BinaryType) => true + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true - case (StringType, BooleanType) => true - case (DateType, BooleanType) => true - case (TimestampType, BooleanType) => true - case (_: NumericType, BooleanType) => true + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true - case (StringType, TimestampType) => true - case (BooleanType, TimestampType) => true - case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => true + case (_, DateType) => true - case (_, DateType) => true + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true + case (_: NumericType, _: NumericType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + canCast(fromType, toType) && + resolvableNullability(fn || forceNullable(fromType, toType), tn) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + canCast(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && + canCast(fromValue, toValue) && + resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (fromField, toField) => + canCast(fromField.dataType, toField.dataType) && + resolvableNullability( + fromField.nullable || forceNullable(fromField.dataType, toField.dataType), + toField.nullable) + } - case (StringType, _: NumericType) => true - case (BooleanType, _: NumericType) => true - case (DateType, _: NumericType) => true - case (TimestampType, _: NumericType) => true - case (_: NumericType, _: NumericType) => true + case _ => false + } - case (ArrayType(from, fn), ArrayType(to, tn)) => - resolve(from, to) && - resolvableNullability(fn || forceNullable(from, to), tn) + private def resolvableNullability(from: Boolean, to: Boolean) = !from || to - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - resolve(fromKey, toKey) && - (!forceNullable(fromKey, toKey)) && - resolve(fromValue, toValue) && - resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + private def forceNullable(from: DataType, to: DataType) = (from, to) match { + case (StringType, _: NumericType) => true + case (StringType, TimestampType) => true + case (DoubleType, TimestampType) => true + case (FloatType, TimestampType) => true + case (StringType, DateType) => true + case (_: NumericType, DateType) => true + case (BooleanType, DateType) => true + case (DateType, _: NumericType) => true + case (DateType, BooleanType) => true + case (DoubleType, _: DecimalType) => true + case (FloatType, _: DecimalType) => true + case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null + case _ => false + } +} - case (StructType(fromFields), StructType(toFields)) => - fromFields.size == toFields.size && - fromFields.zip(toFields).forall { - case (fromField, toField) => - resolve(fromField.dataType, toField.dataType) && - resolvableNullability( - fromField.nullable || forceNullable(fromField.dataType, toField.dataType), - toField.nullable) - } +/** Cast the child expression to the target data type. */ +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - case _ => false + override def checkInputDataTypes(): TypeCheckResult = { + if (Cast.canCast(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") } } - override def toString: String = s"CAST($child, $dataType)" + override def foldable: Boolean = child.foldable - type EvaluatedType = Any + override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable + + override def toString: String = s"CAST($child, $dataType)" // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) - case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) - case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) - case _ => buildCast[Any](_, o => UTF8String(o.toString)) + case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes) + case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) + case TimestampType => buildCast[Long](_, + t => UTF8String.fromString(DateTimeUtils.timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } // BinaryConverter @@ -128,7 +142,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case StringType => buildCast[UTF8String](_, _.length() != 0) case TimestampType => - buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) + buildCast[Long](_, t => t != 0) case DateType => // Hive would return null when cast from date to boolean buildCast[Int](_, d => null) @@ -141,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != 0) + buildCast[Decimal](_, _ != Decimal(0)) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -159,20 +173,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + try DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(n)) + catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0))) + buildCast[Boolean](_, b => if (b) 1L else 0) case LongType => - buildCast[Long](_, l => new Timestamp(l)) + buildCast[Long](_, l => longToTimestamp(l)) case IntegerType => - buildCast[Int](_, i => new Timestamp(i)) + buildCast[Int](_, i => longToTimestamp(i.toLong)) case ShortType => - buildCast[Short](_, s => new Timestamp(s)) + buildCast[Short](_, s => longToTimestamp(s.toLong)) case ByteType => - buildCast[Byte](_, b => new Timestamp(b)) + buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => new Timestamp(DateUtils.toJavaDate(d).getTime)) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 10000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -192,50 +207,30 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w }) } - private[this] def decimalToTimestamp(d: Decimal) = { - val seconds = Math.floor(d.toDouble).toLong - val bd = (d.toBigDecimal - seconds) * 1000000000 - val nanos = bd.intValue() - - val millis = seconds * 1000 - val t = new Timestamp(millis) - - // remaining fractional portion as nanos - t.setNanos(nanos) - t + private[this] def decimalToTimestamp(d: Decimal): Long = { + (d.toBigDecimal * 10000000L).longValue() } - // Timestamp to long, converting milliseconds to seconds - private[this] def timestampToLong(ts: Timestamp) = Math.floor(ts.getTime / 1000.0).toLong - - private[this] def timestampToDouble(ts: Timestamp) = { - // First part is the seconds since the beginning of time, followed by nanosecs. - Math.floor(ts.getTime / 1000.0).toLong + ts.getNanos.toDouble / 1000000000 - } - - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToString(ts: Timestamp): String = { - val timestampString = ts.toString - val formatted = Cast.threadLocalTimestampFormat.get.format(ts) - - if (timestampString.length > 19 && timestampString.substring(19) != ".0") { - formatted + timestampString.substring(19) - } else { - formatted - } + // converting milliseconds to 100ns + private[this] def longToTimestamp(t: Long): Long = t * 10000L + // converting 100ns to seconds + private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 10000000L).toLong + // converting 100ns to seconds in double + private[this] def timestampToDouble(ts: Long): Double = { + ts / 10000000.0 } // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => - try DateUtils.fromJavaDate(Date.valueOf(s.toString)) + try DateTimeUtils.fromJavaDate(Date.valueOf(s.toString)) catch { case _: java.lang.IllegalArgumentException => null } ) case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Timestamp](_, t => DateUtils.millisToDays(t.getTime)) + buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 10000L)) // Hive throws this exception as a Semantic Exception // It is never possible to compare result when hive return with exception, // so we can return null @@ -254,7 +249,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t)) + buildCast[Long](_, t => timestampToLong(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -270,7 +265,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toInt) + buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -286,7 +281,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toShort) + buildCast[Long](_, t => timestampToLong(t).toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -302,7 +297,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toByte) + buildCast[Long](_, t => timestampToLong(t).toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } @@ -325,7 +320,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { - changePrecision(Decimal(s.toString.toDouble), target) + changePrecision(Decimal(new JavaBigDecimal(s.toString)), target) } catch { case _: NumberFormatException => null }) @@ -335,7 +330,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. - buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) + buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case DecimalType() => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) case LongType => @@ -359,7 +354,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t)) + buildCast[Long](_, t => timestampToDouble(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } @@ -375,7 +370,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + buildCast[Long](_, t => timestampToDouble(t).toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } @@ -398,8 +393,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericMutableRow(from.fields.size) - buildCast[Row](_, row => { + val newRow = new GenericMutableRow(from.fields.length) + buildCast[InternalRow](_, row => { var i = 0 while (i < row.length) { val v = row(i) @@ -412,43 +407,72 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def cast(from: DataType, to: DataType): Any => Any = to match { case dt if dt == child.dataType => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) } -} -object Cast { - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - } - } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + // TODO: Add support for more data types. + (child.dataType, dataType) match { + + case (BinaryType, StringType) => + defineCodeGen (ctx, ev, c => + s"${ctx.stringType}.fromBytes($c)") + + case (DateType, StringType) => + defineCodeGen(ctx, ev, c => + s"""${ctx.stringType}.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + + case (TimestampType, StringType) => + defineCodeGen(ctx, ev, c => + s"""${ctx.stringType}.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + + case (_, StringType) => + defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") + + // fallback for DecimalType, this must be before other numeric types + case (_, dt: DecimalType) => + super.genCode(ctx, ev) + + case (BooleanType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + + case (dt: DecimalType, BooleanType) => + defineCodeGen(ctx, ev, c => s"!$c.isZero()") + + case (dt: NumericType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c != 0") + + case (_: DecimalType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + + case (_: NumericType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") + case other => + super.genCode(ctx, ev) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala new file mode 100644 index 000000000000..916e30154d4f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.AbstractDataType + + +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf abstract data type, e.g. NumericType, IntegralType, FractionalType. + */ + def inputTypes: Seq[AbstractDataType] + + override def checkInputDataTypes(): TypeCheckResult = { + // TODO: implement proper type checking. + TypeCheckResult.TypeCheckSuccess + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0837a3179d89..cafbbafdca20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,17 +17,23 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ + +/** + * If an expression wants to be exposed in the function registry (so users can call it with + * "name(arguments...)", the concrete implementation must be a case class whose constructor + * arguments are all Expressions types. + * + * See [[Substring]] for an example. + */ abstract class Expression extends TreeNode[Expression] { self: Product => - /** The narrowest possible type that is produced when this expression is evaluated. */ - type EvaluatedType <: Any - /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. @@ -40,19 +46,66 @@ abstract class Expression extends TreeNode[Expression] { * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false + + /** + * Returns true when the current expression always return the same result for fixed input values. + */ + // TODO: Need to define explicit input values vs implicit input values. + def deterministic: Boolean = true + def nullable: Boolean + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ - def eval(input: Row = null): EvaluatedType + def eval(input: InternalRow = null): Any + + /** + * Returns an [[GeneratedExpressionCode]], which contains Java source code that + * can be used to generate the result of evaluating the expression on an input row. + * + * @param ctx a [[CodeGenContext]] + * @return [[GeneratedExpressionCode]] + */ + def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + ve + } + + /** + * Returns Java source code that can be compiled to evaluate this expression. + * The default behavior is to call the eval method of the expression. Concrete expression + * implementations should override this to do actual code generation. + * + * @param ctx a [[CodeGenContext]] + * @param ev an [[GeneratedExpressionCode]] with unique terms. + * @return Java source code + */ + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ctx.references += this + val objectTerm = ctx.freshName("obj") + s""" + /* expression: ${this} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it still contains any unresolved placeholders. Implementations of expressions - * should override this if the resolution of this type of expression involves more than just - * the resolution of its children. + * and input data types checking passed, and `false` if it still contains any unresolved + * placeholders or has data types mismatch. + * Implementations of expressions should override this if the resolution of this type of + * expression involves more than just the resolution of its children and type checking. */ - lazy val resolved: Boolean = childrenResolved + lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** * Returns the [[DataType]] of the result of evaluating this expression. It is @@ -67,8 +120,39 @@ abstract class Expression extends TreeNode[Expression] { def childrenResolved: Boolean = children.forall(_.resolved) /** - * Returns a string representation of this expression that does not have developer centric - * debugging information like the expression id. + * Returns true when two expressions will always compute the same result, even if they differ + * cosmetically (i.e. capitalization of names in attributes may be different). + */ + def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { + def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = { + elements1.length == elements2.length && elements1.zip(elements2).forall { + case (e1: Expression, e2: Expression) => e1 semanticEquals e2 + case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2 + case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq) + case (i1, i2) => i1 == i2 + } + } + val elements1 = this.productIterator.toSeq + val elements2 = other.asInstanceOf[Product].productIterator.toSeq + checkSemantic(elements1, elements2) + } + + /** + * Checks the input data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `childrenResolved == true`. + */ + def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess + + /** + * Returns a user-facing string representation of this expression's name. + * This should usually match the name of the function in SQL. + */ + def prettyName: String = getClass.getSimpleName.toLowerCase + + /** + * Returns a user-facing string representation of this expression, i.e. does not have developer + * centric debugging information like the expression id. */ def prettyString: String = { transform { @@ -76,47 +160,139 @@ abstract class Expression extends TreeNode[Expression] { case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } -} -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - def symbol: String - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable - - override def toString: String = s"($left $symbol $right)" + override def toString: String = prettyName + children.mkString("(", ",", ")") } + +/** + * A leaf expression, i.e. one without any child expressions. + */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } + +/** + * An expression with one input and one output. The output is by default evaluated to null + * if the input is evaluated to null. + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + * + * As an example, the following does a boolean inversion (i.e. NOT). + * {{{ + * defineCodeGen(ctx, ev, c => s"!($c)") + * }}} + * + * @param f function that accepts a variable name and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: String => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s"$result = ${f(eval)};" + }) + } + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + val eval = child.gen(ctx) + val resultCode = f(ev.primitive, eval.primitive) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $resultCode + } + """ + } } -// TODO Semantically we probably not need GroupExpression -// All we need is holding the Seq[Expression], and ONLY used in doing the -// expressions transformation correctly. Probably will be removed since it's -// not like a real expressions. -case class GroupExpression(children: Seq[Expression]) extends Expression { + +/** + * An expression with two inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. + */ +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => - type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException - override def nullable: Boolean = false - override def foldable: Boolean = false - override def dataType: DataType = throw new UnsupportedOperationException + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable + + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } + + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ + } } + /** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. + * An expression that has two inputs that are expected to the be same type. If the two inputs have + * different types, the analyzer will find the tightest common type and do the proper type casting. */ -trait ExpectsInputTypes { +abstract class BinaryOperator extends BinaryExpression { + self: Product => + + def symbol: String + + override def toString: String = s"($left $symbol $right)" +} - def expectedChildTypes: Seq[DataType] +private[sql] object BinaryOperator { + def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index e05926cbfe74..3020e7fc967f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -21,6 +21,7 @@ import scala.collection.Map import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ object ExtractValue { @@ -38,19 +39,25 @@ object ExtractValue { def apply( child: Expression, extraction: Expression, - resolver: Resolver): ExtractValue = { + resolver: Resolver): Expression = { (child.dataType, extraction) match { - case (StructType(fields), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetStructField(child, fields(ordinal), ordinal) - case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) - case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => + case (StructType(fields), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + + case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) + case (_: MapType, _) => GetMapValue(child, extraction) + case (otherType, _) => val errorMsg = otherType match { case StructType(_) | ArrayType(StructType(_), _) => @@ -67,7 +74,7 @@ object ExtractValue { def unapply(g: ExtractValue): Option[(Expression, Expression)] = { g match { case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case _ => Some((g.child, null)) + case s: ExtractValueWithStruct => Some((s.child, null)) } } @@ -90,27 +97,47 @@ object ExtractValue { } } -trait ExtractValue extends UnaryExpression { +/** + * A common interface of all kinds of extract value expressions. + * Note: concrete extract value expressions are created only by `ExtractValue.apply`, + * we don't need to do type check for them. + */ +trait ExtractValue { + self: Expression => +} + +abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue { self: Product => - type EvaluatedType = Any + def field: StructField + override def toString: String = s"$child.${field.name}" } /** * Returns the value of fields in the Struct `child`. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) - extends ExtractValue { + extends ExtractValueWithStruct { override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" - override def eval(input: Row): Any = { - val baseValue = child.eval(input).asInstanceOf[Row] + override def eval(input: InternalRow): Any = { + val baseValue = child.eval(input).asInstanceOf[InternalRow] if (baseValue == null) null else baseValue(ordinal) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + $result = ${ctx.getColumn(eval, dataType, ordinal)}; + } + """ + }) + } } /** @@ -120,35 +147,54 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, - containsNull: Boolean) extends ExtractValue { + containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" + override def nullable: Boolean = child.nullable || containsNull || field.nullable - override def eval(input: Row): Any = { - val baseValue = child.eval(input).asInstanceOf[Seq[Row]] + override def eval(input: InternalRow): Any = { + val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] if (baseValue == null) null else { baseValue.map { row => if (row == null) null else row(ordinal) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = "scala.collection.mutable.ArraySeq" + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + final int n = $eval.size(); + final $arraySeqClass values = new $arraySeqClass(n); + for (int j = 0; j < n; j++) { + InternalRow row = (InternalRow) $eval.apply(j); + if (row != null && !row.isNullAt($ordinal)) { + values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + } + } + $result = (${ctx.javaType(dataType)}) values; + """ + }) + } } -abstract class ExtractValueWithOrdinal extends ExtractValue { +abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { self: Product => def ordinal: Expression + def child: Expression + + override def left: Expression = child + override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. */ override def nullable: Boolean = true - override def foldable: Boolean = child.foldable && ordinal.foldable override def toString: String = s"$child[$ordinal]" - override def children: Seq[Expression] = child :: ordinal :: Nil - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { null @@ -173,20 +219,30 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType - override lazy val resolved = childrenResolved && - child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType] - protected def evalNotNull(value: Any, ordinal: Any) = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] - val index = ordinal.asInstanceOf[Int] + val index = ordinal.asInstanceOf[Number].intValue() if (index >= baseValue.size || index < 0) { null } else { baseValue(index) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + final int index = (int)$eval2; + if (index >= $eval1.size() || index < 0) { + ${ev.isNull} = true; + } else { + $result = (${ctx.boxedType(dataType)})$eval1.apply(index); + } + """ + }) + } } /** @@ -197,10 +253,20 @@ case class GetMapValue(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType] - protected def evalNotNull(value: Any, ordinal: Any) = { val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + if ($eval1.contains($eval2)) { + $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2); + } else { + ${ev.isNull} = true; + } + """ + }) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 8cae548279eb..fcfe83ceb863 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions - /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. * @param expressions a sequence of expressions that determine the value of each column of the @@ -30,14 +29,14 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null - def apply(input: Row): Row = { + def apply(input: InternalRow): InternalRow = { val outputArray = new Array[Any](exprArray.length) var i = 0 while (i < exprArray.length) { outputArray(i) = exprArray(i).eval(input) i += 1 } - new GenericRow(outputArray) + new GenericInternalRow(outputArray) } override def toString: String = s"Row => [${exprArray.mkString(",")}]" @@ -55,14 +54,14 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu private[this] val exprArray = expressions.toArray private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) - def currentValue: Row = mutableRow + def currentValue: InternalRow = mutableRow override def target(row: MutableRow): MutableProjection = { mutableRow = row this } - override def apply(input: Row): Row = { + override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < exprArray.length) { mutableRow(i) = exprArray(i).eval(input) @@ -76,31 +75,31 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. */ -class JoinedRow extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -136,13 +135,7 @@ class JoinedRow extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -150,7 +143,7 @@ class JoinedRow extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -176,31 +169,31 @@ class JoinedRow extends Row { * Row will be referenced, increasing the opportunity for the JIT to play tricks. This sounds * crazy but in benchmarks it had noticeable effects. */ -class JoinedRow2 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow2 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -236,13 +229,7 @@ class JoinedRow2 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -250,7 +237,7 @@ class JoinedRow2 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -270,31 +257,31 @@ class JoinedRow2 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow3 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow3 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -330,13 +317,7 @@ class JoinedRow3 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -344,7 +325,7 @@ class JoinedRow3 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -364,31 +345,31 @@ class JoinedRow3 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow4 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow4 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -424,13 +405,7 @@ class JoinedRow4 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -438,7 +413,7 @@ class JoinedRow4 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -458,31 +433,31 @@ class JoinedRow4 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow5 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow5 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -518,13 +493,7 @@ class JoinedRow5 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -532,7 +501,7 @@ class JoinedRow5 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -552,31 +521,31 @@ class JoinedRow5 extends Row { /** * JIT HACK: Replace with macros */ -class JoinedRow6 extends Row { - private[this] var row1: Row = _ - private[this] var row2: Row = _ +class JoinedRow6 extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ - def this(left: Row, right: Row) = { + def this(left: InternalRow, right: InternalRow) = { this() row1 = left row2 = right } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: Row, r2: Row): Row = { + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: Row): Row = { + def withLeft(newLeft: InternalRow): InternalRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: Row): Row = { + def withRight(newRight: InternalRow): InternalRow = { row2 = newRight this } @@ -612,13 +581,7 @@ class JoinedRow6 extends Row { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - - override def copy(): Row = { + override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -626,7 +589,7 @@ class JoinedRow6 extends Row { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 9a77ca624ebe..fc055c97a179 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,14 +24,15 @@ import org.apache.spark.sql.types.DataType * User-defined function. * @param dataType Return type of function. */ -case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) - extends Expression { - - type EvaluatedType = Any +case class ScalaUDF( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes { override def nullable: Boolean = true - override def toString: String = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(",")})" // scalastyle:off @@ -40,14 +41,14 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) s"""case $x => val func = function.asInstanceOf[($anys) => Any] $childs $converters - (input: Row) => { + (input: InternalRow) => { func( $evals) } @@ -55,11 +56,11 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi }.foreach(println) */ - - val f = children.size match { - case 0 => + + private[this] val f = children.size match { + case 0 => val func = function.asInstanceOf[() => Any] - (input: Row) => { + (input: InternalRow) => { func() } @@ -67,7 +68,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val func = function.asInstanceOf[(Any) => Any] val child0 = children(0) lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input))) } @@ -78,7 +79,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi val child1 = children(1) lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input))) @@ -92,7 +93,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter0 = CatalystTypeConverters.createToScalaConverter(child0.dataType) lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -109,7 +110,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter1 = CatalystTypeConverters.createToScalaConverter(child1.dataType) lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -129,7 +130,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter2 = CatalystTypeConverters.createToScalaConverter(child2.dataType) lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -152,7 +153,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter3 = CatalystTypeConverters.createToScalaConverter(child3.dataType) lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -178,7 +179,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter4 = CatalystTypeConverters.createToScalaConverter(child4.dataType) lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -207,7 +208,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter5 = CatalystTypeConverters.createToScalaConverter(child5.dataType) lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -239,7 +240,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter6 = CatalystTypeConverters.createToScalaConverter(child6.dataType) lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -274,7 +275,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter7 = CatalystTypeConverters.createToScalaConverter(child7.dataType) lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -312,7 +313,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter8 = CatalystTypeConverters.createToScalaConverter(child8.dataType) lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -353,7 +354,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter9 = CatalystTypeConverters.createToScalaConverter(child9.dataType) lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -397,7 +398,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter10 = CatalystTypeConverters.createToScalaConverter(child10.dataType) lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -444,7 +445,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter11 = CatalystTypeConverters.createToScalaConverter(child11.dataType) lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -494,7 +495,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter12 = CatalystTypeConverters.createToScalaConverter(child12.dataType) lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -547,7 +548,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter13 = CatalystTypeConverters.createToScalaConverter(child13.dataType) lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -603,7 +604,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter14 = CatalystTypeConverters.createToScalaConverter(child14.dataType) lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -662,7 +663,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter15 = CatalystTypeConverters.createToScalaConverter(child15.dataType) lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -724,7 +725,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter16 = CatalystTypeConverters.createToScalaConverter(child16.dataType) lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -789,7 +790,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter17 = CatalystTypeConverters.createToScalaConverter(child17.dataType) lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -857,7 +858,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter18 = CatalystTypeConverters.createToScalaConverter(child18.dataType) lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -928,7 +929,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi lazy val converter19 = CatalystTypeConverters.createToScalaConverter(child19.dataType) lazy val converter20 = CatalystTypeConverters.createToScalaConverter(child20.dataType) lazy val converter21 = CatalystTypeConverters.createToScalaConverter(child21.dataType) - (input: Row) => { + (input: InternalRow) => { func( converter0(child0.eval(input)), converter1(child1.eval(input)), @@ -956,7 +957,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi } // scalastyle:on - - override def eval(input: Row): Any = CatalystTypeConverters.convertToCatalyst(f(input), dataType) - + private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + override def eval(input: InternalRow): Any = converter(f(input)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 83074eb1e631..4baae03b3a22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -29,14 +29,14 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends Expression +case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index aa4099e4d7bf..3928c0f2ffda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * A parent class for mutable container objects that are reused when the values are changed, @@ -195,14 +196,15 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this(dataTypes: Seq[DataType]) = this( dataTypes.map { - case IntegerType => new MutableInt + case BooleanType => new MutableBoolean case ByteType => new MutableByte - case FloatType => new MutableFloat case ShortType => new MutableShort + // We use INT for DATE internally + case IntegerType | DateType => new MutableInt + // We use Long for Timestamp internally + case LongType | TimestampType => new MutableLong + case FloatType => new MutableFloat case DoubleType => new MutableDouble - case BooleanType => new MutableBoolean - case LongType => new MutableLong - case DateType => new MutableInt // We use INT for DATE internally case _ => new MutableAny }.toArray) @@ -220,7 +222,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def isNullAt(i: Int): Boolean = values(i).isNull - override def copy(): Row = { + override def copy(): InternalRow = { val newValues = new Array[Any](values.length) var i = 0 while (i < values.length) { @@ -228,7 +230,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR i += 1 } - new GenericRow(newValues) + new GenericInternalRow(newValues) } override def update(ordinal: Int, value: Any) { @@ -239,7 +241,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR } } - override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) + override def setString(ordinal: Int, value: String): Unit = + update(ordinal, UTF8String.fromString(value)) override def getString(ordinal: Int): String = apply(ordinal).toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 5b2c8572784b..b11fc245c4af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.UTF8String /** * Converts Rows into UnsafeRow format. This class is NOT thread-safe. @@ -32,6 +34,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { this(schema.fields.map(_.dataType)) } + def numFields: Int = fieldTypes.length + /** Re-used pointer to the unsafe row being written */ private[this] val unsafeRow = new UnsafeRow() @@ -47,7 +51,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { /** * Compute the amount of space, in bytes, required to encode the given row. */ - def getSizeRequirement(row: Row): Int = { + def getSizeRequirement(row: InternalRow): Int = { var fieldNumber = 0 var variableLengthFieldSize: Int = 0 while (fieldNumber < writers.length) { @@ -67,19 +71,32 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param baseOffset the base offset of the destination address * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) + + if (writers.length > 0) { + // zero-out the bitset + var n = writers.length / 64 + while (n >= 0) { + PlatformDependent.UNSAFE.putLong( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset + n * 8, + 0L) + n -= 1 + } + } + var fieldNumber = 0 - var appendCursor: Int = fixedLengthSize + var cursor: Int = fixedLengthSize while (fieldNumber < writers.length) { if (row.isNullAt(fieldNumber)) { unsafeRow.setNullAt(fieldNumber) } else { - appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) } fieldNumber += 1 } - appendCursor + cursor } } @@ -94,16 +111,16 @@ private abstract class UnsafeColumnWriter { * @param source the row being converted * @param target a pointer to the converted unsafe row * @param column the column to write - * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * @param cursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written */ - def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int + def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int /** * Return the number of bytes that are needed to write this variable-length value. */ - def getSize(source: Row, column: Int): Int + def getSize(source: InternalRow, column: Int): Int } private object UnsafeColumnWriter { @@ -114,13 +131,13 @@ private object UnsafeColumnWriter { case BooleanType => BooleanUnsafeColumnWriter case ByteType => ByteUnsafeColumnWriter case ShortType => ShortUnsafeColumnWriter - case IntegerType => IntUnsafeColumnWriter - case LongType => LongUnsafeColumnWriter + case IntegerType | DateType => IntUnsafeColumnWriter + case LongType | TimestampType => LongUnsafeColumnWriter case FloatType => FloatUnsafeColumnWriter case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter - case t => - throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + case BinaryType => BinaryUnsafeColumnWriter + case t => ObjectUnsafeColumnWriter } } } @@ -136,88 +153,122 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter +private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter +private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: - def getSize(sourceRow: Row, column: Int): Int = 0 + def getSize(sourceRow: InternalRow, column: Int): Int = 0 } private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 } } -private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { - def getSize(source: Row, column: Int): Int = { - val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) +private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { + + def getBytes(source: InternalRow, column: Int): Array[Byte] + + def getSize(source: InternalRow, column: Int): Int = { + val numBytes = getBytes(source, column).length + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { - val value = source.get(column).asInstanceOf[UTF8String] - val baseObject = target.getBaseObject - val baseOffset = target.getBaseOffset - val numBytes = value.getBytes.length - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + protected[this] def isString: Boolean + + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor + val bytes = getBytes(source, column) + val numBytes = bytes.length + if ((numBytes & 0x07) > 0) { + // zero-out the padding bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L) + } PlatformDependent.copyMemory( - value.getBytes, + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - baseOffset + appendCursor + 8, + target.getBaseObject, + offset, numBytes ) - target.setLong(column, appendCursor) - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 + target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} + +private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = true + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[UTF8String](column).getBytes + } +} + +private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = false + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[Array[Byte]](column) + } +} + +private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(sourceRow: InternalRow, column: Int): Int = 0 + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val obj = source.get(column) + val idx = target.getPool.put(obj) + target.setLong(column, - idx) + 0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index f3830c6d3bcf..64e07bd2a17d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog -import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet abstract class AggregateExpression extends Expression { @@ -37,7 +39,7 @@ abstract class AggregateExpression extends Expression { * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are * replaced with a physical aggregate operator at runtime. */ - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -74,15 +76,13 @@ abstract class AggregateFunction extends AggregateExpression with Serializable with trees.LeafNode[Expression] { self: Product => - override type EvaluatedType = Any - /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression override def nullable: Boolean = base.nullable override def dataType: DataType = base.dataType - def update(input: Row): Unit + def update(input: InternalRow): Unit // Do we really need this? override def newInstance(): AggregateFunction = { @@ -94,7 +94,6 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() @@ -102,6 +101,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MinFunction = new MinFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -110,22 +112,21 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) val cmp = GreaterThan(currentMin, expr) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { if (currentMin.value == null) { currentMin.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMin.value = expr.eval(input) } } - override def eval(input: Row): Any = currentMin.value + override def eval(input: InternalRow): Any = currentMin.value } case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -133,6 +134,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MaxFunction = new MaxFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -141,22 +145,21 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) val cmp = LessThan(currentMax, expr) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { if (currentMax.value == null) { currentMax.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMax.value = expr.eval(input) } } - override def eval(input: Row): Any = currentMax.value + override def eval(input: InternalRow): Any = currentMax.value } case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable: Boolean = false override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -166,6 +169,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance(): CountFunction = new CountFunction(child, this) } +case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var count: Long = _ + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + count += 1L + } + } + + override def eval(input: InternalRow): Any = count +} + case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) @@ -184,6 +202,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate } } +case class CountDistinctFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = seen.size.toLong +} + case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) @@ -207,14 +247,14 @@ case class CollectHashSetFunction( @transient val distinctValue = new InterpretedProjection(expr) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val evaluatedExpr = distinctValue(input) if (!evaluatedExpr.anyNull) { seen.add(evaluatedExpr) } } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { seen } } @@ -240,7 +280,7 @@ case class CombineSetsAndCountFunction( val seen = new OpenHashSet[Any]() - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] val inputIterator = inputSetEval.iterator while (inputIterator.hasNext) { @@ -248,7 +288,7 @@ case class CombineSetsAndCountFunction( } } - override def eval(input: Row): Any = seen.size.toLong + override def eval(input: InternalRow): Any = seen.size.toLong } /** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ @@ -279,6 +319,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctPartitionFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + hyperLogLog.offer(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = hyperLogLog +} + case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { @@ -290,6 +349,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctMergeFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) + } + + override def eval(input: InternalRow): Any = hyperLogLog.cardinality() +} + case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -311,6 +387,8 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def prettyName: String = "avg" + override def nullable: Boolean = true override def dataType: DataType = child.dataType match { @@ -322,8 +400,6 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -350,159 +426,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } override def newInstance(): AverageFunction = new AverageFunction(child, this) -} -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - - override def toString: String = s"SUM($child)" - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() - SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - CombineSum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) -} - -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - override def toString: String = s"SUM(DISTINCT $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: Row): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: Row): Any = { - val casted = seen.asInstanceOf[OpenHashSet[Row]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, this) -} - -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, this) + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -526,7 +452,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) private def addFunction(value: Any) = Add(sum, Cast(Literal.create(value, expr.dataType), calcType)) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { if (count == 0L) { null } else { @@ -543,7 +469,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val evaluatedExpr = expr.eval(input) if (evaluatedExpr != null) { count += 1 @@ -552,55 +478,39 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. +case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - var count: Long = _ + override def nullable: Boolean = true - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType } - override def eval(input: Row): Any = count -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) + override def asPartial: SplitEvaluation = { + child.dataType match { + case DecimalType.Fixed(_, _) => + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + SplitEvaluation( + Cast(CombineSum(partialSum.toAttribute), dataType), + partialSum :: Nil) - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + SplitEvaluation( + CombineSum(partialSum.toAttribute), + partialSum :: Nil) } } - override def eval(input: Row): Any = hyperLogLog -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } + override def newInstance(): SumFunction = new SumFunction(child, this) - override def eval(input: Row): Any = hyperLogLog.cardinality() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") } case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -618,14 +528,14 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { sum.update(addFunction, input) } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { expr.dataType match { case DecimalType.Fixed(_, _) => Cast(sum, dataType).eval(null) @@ -634,9 +544,33 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr } } +/** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class CombineSum(child: Expression) extends AggregateExpression { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) +} + case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - + def this() = this(null, null) // Required for serialization. private val calcType = @@ -651,18 +585,18 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - - override def update(input: Row): Unit = { + + override def update(input: InternalRow): Unit = { val result = expr.eval(input) - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present if(result != null) { sum.update(addFunction, input) } } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { expr.dataType match { case DecimalType.Fixed(_, _) => Cast(sum, dataType).eval(null) @@ -671,6 +605,33 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } +case class SumDistinct(child: Expression) + extends PartialAggregate with trees.UnaryNode[Expression] { + + def this() = this(null) + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) + + override def asPartial: SplitEvaluation = { + val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() + SplitEvaluation( + CombineSetsAndSum(partialSet.toAttribute, this), + partialSet :: Nil) + } + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") +} + case class SumDistinctFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -678,14 +639,14 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) private val seen = new scala.collection.mutable.HashSet[Any]() - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { val evaluatedExpr = expr.eval(input) if (evaluatedExpr != null) { seen += evaluatedExpr } } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { if (seen.size == 0) { null } else { @@ -697,8 +658,20 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CountDistinctFunction( - @transient expr: Seq[Expression], +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { + def this() = this(null, null) + + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } +} + +case class CombineSetsAndSumFunction( + @transient inputSet: Expression, @transient base: AggregateExpression) extends AggregateFunction { @@ -706,17 +679,39 @@ case class CountDistinctFunction( val seen = new OpenHashSet[Any]() - @transient - val distinctValue = new InterpretedProjection(expr) + override def update(input: InternalRow): Unit = { + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) + } + } - override def update(input: Row): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) + override def eval(input: InternalRow): Any = { + val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] + if (casted.size == 0) { + null + } else { + Cast(Literal( + casted.iterator.map(f => f.apply(0)).reduceLeft( + base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + base.dataType).eval(null) } } +} + +case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" - override def eval(input: Row): Any = seen.size.toLong + override def asPartial: SplitEvaluation = { + val partialFirst = Alias(First(child), "PartialFirst")() + SplitEvaluation( + First(partialFirst.toAttribute), + partialFirst :: Nil) + } + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -724,13 +719,28 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag var result: Any = null - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { if (result == null) { result = expr.eval(input) } } - override def eval(input: Row): Any = result + override def eval(input: InternalRow): Any = result +} + +case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" + + override def asPartial: SplitEvaluation = { + val partialLast = Alias(Last(child), "PartialLast")() + SplitEvaluation( + Last(partialLast.toAttribute), + partialLast :: Nil) + } + override def newInstance(): LastFunction = new LastFunction(child, this) } case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -738,11 +748,11 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg var result: Any = null - override def update(input: Row): Unit = { + override def update(input: InternalRow): Unit = { result = input } - override def eval(input: Row): Any = { - if (result != null) expr.eval(result.asInstanceOf[Row]) else null + override def eval(input: InternalRow): Any = { + if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c7a37ad966df..4fbf4c87009c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,77 +17,84 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -case class UnaryMinus(child: Expression) extends UnaryExpression { - type EvaluatedType = Any +abstract class UnaryArithmetic extends UnaryExpression { + self: Product => override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"-$child" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - numeric.negate(evalE) + evalInternal(evalE) } } + + protected def evalInternal(evalE: Any): Any = + sys.error(s"UnaryArithmetics must override either eval or evalInternal") } -case class Sqrt(child: Expression) extends UnaryExpression { - type EvaluatedType = Any +case class UnaryMinus(child: Expression) extends UnaryArithmetic { + override def toString: String = s"-$child" - override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = true - override def toString: String = s"SQRT($child)" + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "operator -") - lazy val numeric = child.dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support non-negative numeric operations") - } + private lazy val numeric = TypeUtils.getNumeric(dataType) - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") } + + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) +} + +case class UnaryPositive(child: Expression) extends UnaryArithmetic { + override def prettyName: String = "positive" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + defineCodeGen(ctx, ev, c => c) + + protected override def evalInternal(evalE: Any) = evalE +} + +/** + * A function that get the absolute value of the numeric value. + */ +case class Abs(child: Expression) extends UnaryArithmetic { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function abs") + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryExpression { +abstract class BinaryArithmetic extends BinaryOperator { self: Product => - type EvaluatedType = Any + override def dataType: DataType = left.dataType - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) } - left.dataType } - override def eval(input: Row): Any = { + protected def checkTypesInternal(t: DataType): TypeCheckResult + + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { null @@ -101,91 +108,92 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = - sys.error(s"BinaryExpressions must either override eval or evalInternal") + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = + sys.error("BinaryArithmetics must override either decimalMethod or genCode") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } + + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryArithmetics must override either eval or evalInternal") +} + +private[sql] object BinaryArithmetic { + def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.plus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" + override def decimalMethod: String = "$minus" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.minus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" + override def decimalMethod: String = "$times" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.times(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" + override def decimalMethod: String = "$div" override def nullable: Boolean = true - lazy val div: (Any, Any) => Any = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot - case other => sys.error(s"Type $other does not support numeric operations") } - - override def eval(input: Row): Any = { + + override def eval(input: InternalRow): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { null @@ -198,20 +206,60 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } } + + /** + * Special case handling due to division by 0 => null. + */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val isZero = if (dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" + } else { + s"${eval2.primitive} == 0" + } + val javaType = ctx.javaType(dataType) + val divide = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $divide; + } + } + """ + } } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" + override def decimalMethod: String = "remainder" override def nullable: Boolean = true - lazy val integral = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] - case other => sys.error(s"Type $other does not support numeric operations") } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { null @@ -224,131 +272,51 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } } -} - -/** - * A function that calculates bitwise and(&) of two numbers. - */ -case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "&" - - lazy val and: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise & operation on $other") - } - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = and(evalE1, evalE2) -} - -/** - * A function that calculates bitwise or(|) of two numbers. - */ -case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "|" - - lazy val or: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise | operation on $other") - } - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = or(evalE1, evalE2) -} - -/** - * A function that calculates bitwise xor(^) of two numbers. - */ -case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "^" - - lazy val xor: (Any, Any) => Any = dataType match { - case ByteType => - ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] - case ShortType => - ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] - case IntegerType => - ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case LongType => - ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise ^ operation on $other") - } - - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = xor(evalE1, evalE2) -} - -/** - * A function that calculates bitwise not(~) of a number. - */ -case class BitwiseNot(child: Expression) extends UnaryExpression { - type EvaluatedType = Any - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"~$child" - - lazy val not: (Any) => Any = dataType match { - case ByteType => - ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] - case ShortType => - ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] - case IntegerType => - ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] - case LongType => - ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") - } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null + /** + * Special case handling for x % 0 ==> null. + */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val isZero = if (dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" } else { - not(evalE) + s"${eval2.primitive} == 0" } + val javaType = ctx.javaType(dataType) + val remainder = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $remainder; + } + } + """ } } -case class MaxOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any - - override def foldable: Boolean = left.foldable && right.foldable - +case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function maxOf") - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) val evalE2 = right.eval(input) if (evalE1 == null) { @@ -364,36 +332,45 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } } - override def toString: String = s"MaxOf($left, $right)" -} - -case class MinOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive) + + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if ($compCode > 0) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ + } - override def foldable: Boolean = left.foldable && right.foldable + override def symbol: String = "max" + override def prettyName: String = symbol +} +case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function minOf") - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) val evalE2 = right.eval(input) if (evalE1 == null) { @@ -409,31 +386,32 @@ case class MinOf(left: Expression, right: Expression) extends Expression { } } - override def toString: String = s"MinOf($left, $right)" -} - -/** - * A function that get the absolute value of the numeric value. - */ -case class Abs(child: Expression) extends UnaryExpression { - type EvaluatedType = Any - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"Abs($child)" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val compCode = ctx.genComp(dataType, eval1.primitive, eval2.primitive) + + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if ($compCode < 0) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - numeric.abs(evalE) - } - } + override def symbol: String = "min" + override def prettyName: String = symbol } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala new file mode 100644 index 000000000000..9002dda7bf4d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + + +/** + * A function that calculates bitwise and(&) of two numbers. + * + * Code generation inherited from BinaryArithmetic. + */ +case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "&" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val and: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) +} + +/** + * A function that calculates bitwise or(|) of two numbers. + * + * Code generation inherited from BinaryArithmetic. + */ +case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "|" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val or: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) +} + +/** + * A function that calculates bitwise xor of two numbers. + * + * Code generation inherited from BinaryArithmetic. + */ +case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { + override def symbol: String = "^" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val xor: (Any, Any) => Any = dataType match { + case ByteType => + ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] + case ShortType => + ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any] + case IntegerType => + ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + case LongType => + ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] + } + + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) +} + +/** + * A function that calculates bitwise not(~) of a number. + */ +case class BitwiseNot(child: Expression) extends UnaryArithmetic { + override def toString: String = s"~$child" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + + private lazy val not: (Any) => Any = dataType match { + case ByteType => + ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] + case ShortType => + ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any] + case IntegerType => + ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] + case LongType => + ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + } + + protected override def evalInternal(evalE: Any) = not(evalE) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ecb4c4b68f90..9f6329bbda4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,678 +17,283 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.google.common.cache.{CacheLoader, CacheBuilder} - +import scala.collection.mutable import scala.language.existentials +import com.google.common.cache.{CacheBuilder, CacheLoader} +import org.codehaus.janino.ClassBodyEvaluator + import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + // These classes are here to avoid issues with serialization and integration with quasiquotes. class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] /** - * A base class for generators of byte code to perform expression evaluation. Includes a set of - * helpers for referring to Catalyst types and building trees that perform evaluation of individual - * expressions. + * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param isNull A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitive A term for a possible primitive value of the result of the evaluation. Not + * valid if `isNull` is set to `true`. */ -abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ +case class GeneratedExpressionCode(var code: String, var isNull: String, var primitive: String) - import scala.tools.reflect.ToolBox +/** + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. + */ +class CodeGenContext { - protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - protected val rowType = typeOf[Row] - protected val mutableRowType = typeOf[MutableRow] - protected val genericRowType = typeOf[GenericRow] - protected val genericMutableRowType = typeOf[GenericMutableRow] + val stringType: String = classOf[UTF8String].getName + val decimalType: String = classOf[Decimal].getName - protected val projectionType = typeOf[Projection] - protected val mutableProjectionType = typeOf[MutableProjection] + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" private val curId = new java.util.concurrent.atomic.AtomicInteger() - private val javaSeparator = "$" /** - * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) */ - var debugLogging = false + def freshName(prefix: String): String = { + s"$prefix${curId.getAndIncrement}" + } /** - * Generates a class for a given input expression. Called when there is not cached code - * already available. + * Returns the code to access a column in Row for a given DataType. */ - protected def create(in: InType): OutType + def getColumn(row: String, dataType: DataType, ordinal: Int): String = { + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"$row.get${primitiveTypeName(jt)}($ordinal)" + } else { + s"($jt)$row.apply($ordinal)" + } + } /** - * Canonicalizes an input expression. Used to avoid double caching expressions that differ only - * cosmetically. + * Returns the code to update a column in Row for a given DataType. */ - protected def canonicalize(in: InType): InType - - /** Binds an input expression to a given input schema */ - protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + } else { + s"$row.update($ordinal, $value)" + } + } /** - * A cache of generated classes. - * - * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most - * fundamental difference is that a ConcurrentMap persists all elements that are added to it until - * they are explicitly removed. A Cache on the other hand is generally configured to evict entries - * automatically, in order to constrain its memory footprint. Note that this cache does not use - * weak keys/values and thus does not respond to memory pressure. + * Returns the name used in accessor and setter for a Java primitive type. */ - protected val cache = CacheBuilder.newBuilder() - .maximumSize(1000) - .build( - new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = globalLock.synchronized { - val startTime = System.nanoTime() - val result = create(in) - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logInfo(s"Code generated expression $in in $timeMs ms") - result - } - }) - - /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = - generate(bind(expressions, inputSchema)) + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) + } - /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) + * Returns the Java type for a DataType. */ - protected def freshName(prefix: String): TermName = { - newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + def javaType(dt: DataType): String = dt match { + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case _: StructType => "InternalRow" + case _: ArrayType => s"scala.collection.Seq" + case _: MapType => s"scala.collection.Map" + case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName + case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case _ => "Object" } /** - * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. - * - * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated - * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `true`. - * @param objectTerm A possibly boxed version of the result of evaluating this expression. + * Returns the boxed type in Java. */ - protected case class EvaluatedExpression( - code: Seq[Tree], - nullTerm: TermName, - primitiveTerm: TermName, - objectTerm: TermName) + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other + } + + def boxedType(dt: DataType): String = boxedType(javaType(dt)) /** - * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that - * can be used to determine the result of evaluating the expression on an input row. + * Returns the representation of default value for a given Java Type. */ - def expressionEvaluator(e: Expression): EvaluatedExpression = { - val primitiveTerm = freshName("primitiveTerm") - val nullTerm = freshName("nullTerm") - val objectTerm = freshName("objectTerm") - - implicit class Evaluate1(e: Expression) { - def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${f(eval.primitiveTerm)} - """.children - } - } - - implicit class Evaluate2(expressions: (Expression, Expression)) { - - /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f a function from two primitive term names to a tree that evaluates them. - */ - def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = - evaluateAs(expressions._1.dataType)(f) - - def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (expressions._1.dataType != expressions._2.dataType) { - log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") - } + def defaultValue(jt: String): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => "null" + } - val eval1 = expressionEvaluator(expressions._1) - val eval2 = expressionEvaluator(expressions._2) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - - eval1.code ++ eval2.code ++ - q""" - val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} - val $primitiveTerm: ${termForType(resultType)} = - if($nullTerm) { - ${defaultPrimitive(resultType)} - } else { - $resultCode.asInstanceOf[${termForType(resultType)}] - } - """.children : Seq[Tree] - } - } + def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) - val inputTuple = newTermName(s"i") - - // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { - case b @ BoundReference(ordinal, dataType, nullable) => - val nullValue = q"$inputTuple.isNullAt($ordinal)" - q""" - val $nullTerm: Boolean = $nullValue - val $primitiveTerm: ${termForType(dataType)} = - if($nullTerm) - ${defaultPrimitive(dataType)} - else - ${getColumn(inputTuple, dataType, ordinal)} - """.children - - case expressions.Literal(null, dataType) => - q""" - val $nullTerm = true - val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] - """.children - - case expressions.Literal(value: Boolean, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: UTF8String, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = - org.apache.spark.sql.types.UTF8String(${value.getBytes}) - """.children - - case expressions.Literal(value: Int, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case expressions.Literal(value: Long, dataType) => - q""" - val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value - """.children - - case Cast(e @ BinaryType(), StringType) => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) - """.children - - case Cast(child @ DateType(), StringType) => - child.castOrNull(c => - q"""org.apache.spark.sql.types.UTF8String( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) - - case Cast(child @ NumericType(), IntegerType) => - child.castOrNull(c => q"$c.toInt", IntegerType) - - case Cast(child @ NumericType(), LongType) => - child.castOrNull(c => q"$c.toLong", LongType) - - case Cast(child @ NumericType(), DoubleType) => - child.castOrNull(c => q"$c.toDouble", DoubleType) - - case Cast(child @ NumericType(), FloatType) => - child.castOrNull(c => q"$c.toFloat", FloatType) - - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. - case Cast(e, StringType) if e.dataType != TimestampType => - val eval = expressionEvaluator(e) - eval.code ++ - q""" - val $nullTerm = ${eval.nullTerm} - val $primitiveTerm = - if($nullTerm) - ${defaultPrimitive(StringType)} - else - org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) - """.children - - case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => - (e1, e2).evaluateAs (BooleanType) { - case (eval1, eval2) => - q""" - java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], - $eval2.asInstanceOf[Array[Byte]]) - """ - } + /** + * Generates code for equal expression in Java. + */ + def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { + case BinaryType => s"java.util.Arrays.equals($c1, $c2)" + case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" + case other => s"$c1.equals($c2)" + } - case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } + /** + * Generates code for compare expression in Java. + */ + def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { + // java boolean doesn't support > or < operator + case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" + // use c1 - c2 may overflow + case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" + case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" + case NullType => "0" + case other => s"$c1.compare($c2)" + } - /* TODO: Fix null semantics. - case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => - val eval = expressionEvaluator(e1) + /** + * List of java data types that have special accessors and setters in [[InternalRow]]. + */ + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) - val checks = list.map { - case expressions.Literal(v: String, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - case expressions.Literal(v: Int, dataType) => - q"if(${eval.primitiveTerm} == $v) return true" - } + /** + * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. + */ + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - val funcName = newTermName(s"isIn${curId.getAndIncrement()}") - - q""" - def $funcName: Boolean = { - ..${eval.code} - if(${eval.nullTerm}) return false - ..$checks - return false - } - val $nullTerm = false - val $primitiveTerm = $funcName - """.children - */ - - case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } - case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } - case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } - case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } - - case And(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm} == false) { - } else { - ..${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm} == false) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true - } else { - $nullTerm = true - } - } - """.children - - case Or(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - q""" - ..${eval1.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = false - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true - } else { - ..${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false - } else { - $nullTerm = true - } - } - """.children - - case Not(child) => - // Uh, bad function name... - child.castOrNull(c => q"!$c", BooleanType) - - case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } - case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } - case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } - case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm} - } - """.children - - case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = 0 - - if (${eval1.nullTerm} || ${eval2.nullTerm} ) { - $nullTerm = true - } else if (${eval2.primitiveTerm} == 0) - $nullTerm = true - else { - $nullTerm = false - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm} - } - """.children - - case IsNotNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} - """.children - - case IsNull(e) => - val eval = expressionEvaluator(e) - q""" - ..${eval.code} - var $nullTerm = false - var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} - """.children - - case c @ Coalesce(children) => - q""" - var $nullTerm = true - var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} - """.children ++ - children.map { c => - val eval = expressionEvaluator(c) - q""" - if($nullTerm) { - ..${eval.code} - if(!${eval.nullTerm}) { - $nullTerm = false - $primitiveTerm = ${eval.primitiveTerm} - } - } - """ - } + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) +} - case i @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition) - val trueEval = expressionEvaluator(trueValue) - val falseEval = expressionEvaluator(falseValue) - - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} - ..${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ..${trueEval.code} - $nullTerm = ${trueEval.nullTerm} - $primitiveTerm = ${trueEval.primitiveTerm} - } else { - ..${falseEval.code} - $nullTerm = ${falseEval.nullTerm} - $primitiveTerm = ${falseEval.primitiveTerm} - } - """.children - - case NewSet(elementType) => - q""" - val $nullTerm = false - val $primitiveTerm = new ${hashSetForType(elementType)}() - """.children - - case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item) - val setEval = expressionEvaluator(set) - - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - - itemEval.code ++ setEval.code ++ - q""" - if (!${itemEval.nullTerm}) { - ${setEval.primitiveTerm} - .asInstanceOf[${hashSetForType(elementType)}] - .add(${itemEval.primitiveTerm}) - } - - val $nullTerm = false - val $primitiveTerm = ${setEval.primitiveTerm} - """.children - - case CombineSets(left, right) => - val leftEval = expressionEvaluator(left) - val rightEval = expressionEvaluator(right) - - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - - leftEval.code ++ rightEval.code ++ - q""" - val $nullTerm = false - var $primitiveTerm: ${hashSetForType(elementType)} = null - - { - val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val iterator = rightSet.iterator - while (iterator.hasNext) { - leftSet.add(iterator.next()) - } - $primitiveTerm = leftSet - } - """.children - - case MaxOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} - } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} - } else { - $primitiveTerm = ${eval2.primitiveTerm} - } - } - """.children - - case MinOf(e1, e2) => - val eval1 = expressionEvaluator(e1) - val eval2 = expressionEvaluator(e2) - - eval1.code ++ eval2.code ++ - q""" - var $nullTerm = false - var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm} - $primitiveTerm = ${eval2.primitiveTerm} - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm} - $primitiveTerm = ${eval1.primitiveTerm} - } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm} - } else { - $primitiveTerm = ${eval2.primitiveTerm} - } - } - """.children - - case UnscaledValue(child) => - val childEval = expressionEvaluator(child) - - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: Long = if (!$nullTerm) { - ${childEval.primitiveTerm}.toUnscaledLong - } else { - ${defaultPrimitive(LongType)} - } - """.children - - case MakeDecimal(child, precision, scale) => - val childEval = expressionEvaluator(child) - - childEval.code ++ - q""" - var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: org.apache.spark.sql.types.Decimal = - ${defaultPrimitive(DecimalType())} - - if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal() - $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) - $nullTerm = $primitiveTerm == null - } - """.children - } - // If there was no match in the partial function above, we fall back on calling the interpreted - // expression evaluator. - val code: Seq[Tree] = - primitiveEvaluation.lift.apply(e).getOrElse { - log.debug(s"No rules to generate $e") - val tree = reify { e } - q""" - val $objectTerm = $tree.eval(i) - val $nullTerm = $objectTerm == null - val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] - """.children - } - - // Only inject debugging code if debugging is turned on. - val debugCode = - if (debugLogging) { - val localLogger = log - val localLoggerTree = reify { localLogger } - q""" - $localLoggerTree.debug( - ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) - """ :: Nil - } else { - Nil - } - - EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) - } +abstract class GeneratedClass { + def generate(expressions: Array[Expression]): Any +} - protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { - dataType match { - case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" - case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" - } - } +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected def setColumn( - destinationRow: TermName, - dataType: DataType, - ordinal: Int, - value: TermName) = { - dataType match { - case StringType => q"$destinationRow.update($ordinal, $value)" - case dt: DataType if isNativeType(dt) => - q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => q"$destinationRow.update($ordinal, $value)" - } - } + protected val exprType: String = classOf[Expression].getName + protected val mutableRowType: String = classOf[MutableRow].getName + protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") - protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + /** + * Generates a class for a given input expression. Called when there is not cached code + * already available. + */ + protected def create(in: InType): OutType - protected def hashSetForType(dt: DataType) = dt match { - case IntegerType => typeOf[IntegerHashSet] - case LongType => typeOf[LongHashSet] - case unsupportedType => - sys.error(s"Code generation not support for hashset of type $unsupportedType") - } + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ + protected def canonicalize(in: InType): InType - protected def primitiveForType(dt: DataType) = dt match { - case IntegerType => "Int" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case StringType => "org.apache.spark.sql.types.UTF8String" - } + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType - protected def defaultPrimitive(dt: DataType) = dt match { - case BooleanType => ru.Literal(Constant(false)) - case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" - case ShortType => ru.Literal(Constant(-1.toShort)) - case LongType => ru.Literal(Constant(-1L)) - case ByteType => ru.Literal(Constant(-1.toByte)) - case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" - case IntegerType => ru.Literal(Constant(-1)) - case DateType => ru.Literal(Constant(-1)) - case _ => ru.Literal(Constant(null)) + /** + * Compile the Java source code into a Java class, using Janino. + */ + protected def compile(code: String): GeneratedClass = { + cache.get(code) } - protected def termForType(dt: DataType) = dt match { - case n: AtomicType => n.tag - case _ => typeTag[Any] + /** + * Compile the Java source code into a Java class, using Janino. + */ + private[this] def doCompile(code: String): GeneratedClass = { + val evaluator = new ClassBodyEvaluator() + evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) + evaluator.setExtendedClass(classOf[GeneratedClass]) + try { + evaluator.cook(code) + } catch { + case e: Exception => + logError(s"failed to compile:\n $code", e) + throw e + } + evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } /** - * List of data types that have special accessors and setters in [[Row]]. + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint. Note that this cache does not use + * weak keys/values and thus does not respond to memory pressure. */ - protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + private val cache = CacheBuilder.newBuilder() + .maximumSize(100) + .build( + new CacheLoader[String, GeneratedClass]() { + override def load(code: String): GeneratedClass = { + val startTime = System.nanoTime() + val result = doCompile(code) + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logInfo(s"Code generated in $timeMs ms") + result + } + }) + + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def generate(expressions: InType): OutType = create(canonicalize(expressions)) /** - * Returns true if the data type has a special accessor and setter in [[Row]]. + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen */ - protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) + def newCodeGenContext(): CodeGenContext = { + new CodeGenContext + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 840260703ab7..addb8023d9c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +// MutableProjection is not accessible in Java +abstract class BaseMutableProjection extends MutableProjection + /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new - * input [[Row]] for a fixed set of [[Expression Expressions]]. + * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. */ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ - - val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -36,41 +35,56 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu in.map(BindReferences.bindReference(_, inputSchema)) protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { - val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => - val evaluationCode = expressionEvaluator(e) - - evaluationCode.code :+ - q""" - if(${evaluationCode.nullTerm}) - mutableRow.setNullAt($i) - else - ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} - """ - } + val ctx = newCodeGenContext() + val projectionCode = expressions.zipWithIndex.map { case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if(${evaluationCode.isNull}) + mutableRow.setNullAt($i); + else + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + """ + }.mkString("\n") + val code = s""" + public Object generate($exprType[] expr) { + return new SpecificProjection(expr); + } + + class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - val code = - q""" - () => { new $mutableProjectionType { + private $exprType[] expressions = null; + private $mutableRowType mutableRow = null; - private[this] var $mutableRowName: $mutableRowType = - new $genericMutableRowType(${expressions.size}) + public SpecificProjection($exprType[] expr) { + expressions = expr; + mutableRow = new $genericMutableRowType(${expressions.size}); + } - def target(row: $mutableRowType): $mutableProjectionType = { - $mutableRowName = row - this - } + public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + mutableRow = row; + return this; + } - /* Provide immutable access to the last projected row. */ - def currentValue: $rowType = mutableRow + /* Provide immutable access to the last projected row. */ + public InternalRow currentValue() { + return (InternalRow) mutableRow; + } - def apply(i: $rowType): $rowType = { - ..$projectionCode - mutableRow - } - } } - """ + public Object apply(Object _i) { + InternalRow i = (InternalRow) _i; + $projectionCode - log.debug(s"code for ${expressions.mkString(",")}:\n$code") - toolBox.eval(code).asInstanceOf[() => MutableProjection] + return mutableRow; + } + } + """ + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + () => { + c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index b129c0d898bb..97cb16045ae4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -18,93 +18,83 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging +import org.apache.spark.annotation.Private +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} + +/** + * Inherits some default implementation for Java from `Ordering[Row]` + */ +@Private +class BaseOrdering extends Ordering[InternalRow] { + def compare(a: InternalRow, b: InternalRow): Int = { + throw new UnsupportedOperationException + } +} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of * [[Expression Expressions]]. */ -object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ +object GenerateOrdering + extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging { - protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = in.map(BindReferences.bindReference(_, inputSchema)) - protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { - val a = newTermName("a") - val b = newTermName("b") - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child) - val evalB = expressionEvaluator(order.child) + protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { + val ctx = newCodeGenContext() - val compare = order.child.dataType match { - case BinaryType => - q""" - val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} - val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} - var i = 0 - while (i < x.length && i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - i = i+1 - } - return x.length - y.length - """ - case _: NumericType => - q""" - val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} - if(comp != 0) { - return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} - } - """ - case StringType => - if (order.direction == Ascending) { - q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = order.child.gen(ctx) + val evalB = order.child.gen(ctx) + val asc = order.direction == Ascending + s""" + i = a; + ${evalA.code} + i = b; + ${evalB.code} + if (${evalA.isNull} && ${evalB.isNull}) { + // Nothing + } else if (${evalA.isNull}) { + return ${if (order.direction == Ascending) "-1" else "1"}; + } else if (${evalB.isNull}) { + return ${if (order.direction == Ascending) "1" else "-1"}; } else { - q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + if (comp != 0) { + return ${if (asc) "comp" else "-comp"}; + } } - } - - q""" - i = $a - ..${evalA.code} - i = $b - ..${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { - // Nothing - } else if (${evalA.nullTerm}) { - return ${if (order.direction == Ascending) q"-1" else q"1"} - } else if (${evalB.nullTerm}) { - return ${if (order.direction == Ascending) q"1" else q"-1"} - } else { - $compare - } """ - } + }.mkString("\n") - val q"class $orderingName extends $orderingType { ..$body }" = reify { - class SpecificOrdering extends Ordering[Row] { - val o = ordering + val code = s""" + public SpecificOrdering generate($exprType[] expr) { + return new SpecificOrdering(expr); } - }.tree.children.head - val code = q""" - class $orderingName extends $orderingType { - ..$body - def compare(a: $rowType, b: $rowType): Int = { - var i: $rowType = null // Holds current row being evaluated. - ..$comparisons - return 0 + class SpecificOrdering extends ${classOf[BaseOrdering].getName} { + + private $exprType[] expressions = null; + + public SpecificOrdering($exprType[] expr) { + expressions = expr; } - } - new $orderingName() - """ + + @Override + public int compare(InternalRow a, InternalRow b) { + InternalRow i = null; // Holds current row being evaluated. + $comparisons + return 0; + } + }""" + logDebug(s"Generated Ordering: $code") - toolBox.eval(code).asInstanceOf[Ordering[Row]] + + compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 40e163024360..3ebc2c147579 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -20,29 +20,46 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ /** - * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. + * Interface for generated predicate */ -object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ +abstract class Predicate { + def eval(r: InternalRow): Boolean +} + +/** + * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]]. + */ +object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] { protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) - protected def create(predicate: Expression): ((Row) => Boolean) = { - val cEval = expressionEvaluator(predicate) + protected def create(predicate: Expression): ((InternalRow) => Boolean) = { + val ctx = newCodeGenContext() + val eval = predicate.gen(ctx) + val code = s""" + public SpecificPredicate generate($exprType[] expr) { + return new SpecificPredicate(expr); + } - val code = - q""" - (i: $rowType) => { - ..${cEval.code} - if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + class SpecificPredicate extends ${classOf[Predicate].getName} { + private final $exprType[] expressions; + public SpecificPredicate($exprType[] expr) { + expressions = expr; } - """ - log.debug(s"Generated predicate '$predicate':\n$code") - toolBox.eval(code).asInstanceOf[Row => Boolean] + @Override + public boolean eval(InternalRow i) { + ${eval.code} + return !${eval.isNull} && ${eval.primitive}; + } + }""" + + logDebug(s"Generated predicate '$predicate':\n$code") + + val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 584f938445c8..3c7ee9cc1659 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -20,15 +20,18 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProject extends Projection {} /** - * Generates bytecode that produces a new [[Row]] object based on a fixed set of input - * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom - * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. + * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input + * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] + * object is custom generated based on the output types of the [[Expression]] to avoid boxing of + * primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.{universe => ru} - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -38,201 +41,195 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { // Make Mutablility optional... protected def create(expressions: Seq[Expression]): Projection = { - val tupleLength = ru.Literal(Constant(expressions.length)) - val lengthDef = q"final val length = $tupleLength" - - /* TODO: Configurable... - val nullFunctions = - q""" - private final val nullSet = new org.apache.spark.util.collection.BitSet(length) - final def setNullAt(i: Int) = nullSet.set(i) - final def isNullAt(i: Int) = nullSet.get(i) - """ - */ - - val nullFunctions = - q""" - private[this] var nullBits = new Array[Boolean](${expressions.size}) - override def setNullAt(i: Int) = { nullBits(i) = true } - override def isNullAt(i: Int) = nullBits(i) - """.children - - val tupleElements = expressions.zipWithIndex.flatMap { + val ctx = newCodeGenContext() + val columns = expressions.zipWithIndex.map { case (e, i) => - val elementName = newTermName(s"c$i") - val evaluatedExpression = expressionEvaluator(e) - val iLit = ru.Literal(Constant(i)) + s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" + }.mkString("\n ") - q""" - var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + val initColumns = expressions.zipWithIndex.map { + case (e, i) => + val eval = e.gen(ctx) + s""" { - ..${evaluatedExpression.code} - if(${evaluatedExpression.nullTerm}) - setNullAt($iLit) - else { - nullBits($iLit) = false - $elementName = ${evaluatedExpression.primitiveTerm} + // column$i + ${eval.code} + nullBits[$i] = ${eval.isNull}; + if (!${eval.isNull}) { + c$i = ${eval.primitive}; } } - """.children : Seq[Tree] - } - - val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" - val applyFunction = { - val cases = (0 until expressions.size).map { i => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) - - q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" - } - q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }" - } - - val updateFunction = { - val cases = expressions.zipWithIndex.map {case (e, i) => - val ordinal = ru.Literal(Constant(i)) - val elementName = newTermName(s"c$i") - val iLit = ru.Literal(Constant(i)) - - q""" - if(i == $ordinal) { - if(value == null) { - setNullAt(i) - } else { - nullBits(i) = false - $elementName = value.asInstanceOf[${termForType(e.dataType)}] - } - return - }""" - } - q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" - } - - val specificAccessorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // getString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) return $elementName" :: Nil - case _ => Nil - } - dataType match { - // Row() need this interface to compile - case StringType => - q""" - override def getString(i: Int): String = { - $accessorFailure - }""" - case other => - q""" - override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" - } - } - - val specificMutatorFunctions = nativeTypes.map { dataType => - val ifStatements = expressions.zipWithIndex.flatMap { - // setString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => - val elementName = newTermName(s"c$i") - // TODO: The string of ifs gets pretty inefficient as the row grows in size. - // TODO: Optional null checks? - q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil - case _ => Nil + """ + }.mkString("\n") + + val getCases = (0 until expressions.size).map { i => + s"case $i: return c$i;" + }.mkString("\n ") + + val updateCases = expressions.zipWithIndex.map { case (e, i) => + s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" + }.mkString("\n ") + + val specificAccessorFunctions = ctx.primitiveTypes.map { jt => + val cases = expressions.zipWithIndex.flatMap { + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: return c$i;") + case _ => None + }.mkString("\n ") + if (cases.length > 0) { + val getter = "get" + ctx.primitiveTypeName(jt) + s""" + @Override + public $jt $getter(int i) { + if (isNullAt(i)) { + return ${ctx.defaultValue(jt)}; + } + switch (i) { + $cases + } + throw new IllegalArgumentException("Invalid index: " + i + + " in $getter"); + }""" + } else { + "" } - dataType match { - case StringType => - // MutableRow() need this interface to compile - q""" - override def setString(i: Int, value: String) { - $accessorFailure - }""" - case other => - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { - ..$ifStatements; - $accessorFailure - }""" + }.filter(_.length > 0).mkString("\n") + + val specificMutatorFunctions = ctx.primitiveTypes.map { jt => + val cases = expressions.zipWithIndex.flatMap { + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: { c$i = value; return; }") + case _ => None + }.mkString("\n ") + if (cases.length > 0) { + val setter = "set" + ctx.primitiveTypeName(jt) + s""" + @Override + public void $setter(int i, $jt value) { + nullBits[i] = false; + switch (i) { + $cases + } + throw new IllegalArgumentException("Invalid index: " + i + + " in $setter}"); + }""" + } else { + "" } - } + }.filter(_.length > 0).mkString("\n") - val hashValues = expressions.zipWithIndex.map { case (e,i) => - val elementName = newTermName(s"c$i") + val hashValues = expressions.zipWithIndex.map { case (e, i) => + val col = s"c$i" val nonNull = e.dataType match { - case BooleanType => q"if ($elementName) 0 else 1" - case ByteType | ShortType | IntegerType => q"$elementName.toInt" - case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" - case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case BooleanType => s"$col ? 0 : 1" + case ByteType | ShortType | IntegerType | DateType => s"$col" + case LongType | TimestampType => s"$col ^ ($col >>> 32)" + case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" - case _ => q"$elementName.hashCode" + s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" + case BinaryType => s"java.util.Arrays.hashCode($col)" + case _ => s"$col.hashCode()" } - q"if (isNullAt($i)) 0 else $nonNull" + s"isNullAt($i) ? 0 : ($nonNull)" } - val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + val hashUpdates: String = hashValues.map( v => + s""" + result *= 37; result += $v;""" + ).mkString("\n") - val hashCodeFunction = - q""" - override def hashCode(): Int = { - var result: Int = 37 - ..$hashUpdates - result + val columnChecks = expressions.zipWithIndex.map { case (e, i) => + s""" + if (nullBits[$i] != row.nullBits[$i] || + (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { + return false; } """ + }.mkString("\n") - val columnChecks = (0 until expressions.size).map { i => - val elementName = newTermName(s"c$i") - q"if (this.$elementName != specificType.$elementName) return false" + val copyColumns = expressions.zipWithIndex.map { case (e, i) => + s"""if (!nullBits[$i]) arr[$i] = c$i;""" + }.mkString("\n ") + + val code = s""" + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); } - val equalsFunction = - q""" - override def equals(other: Any): Boolean = other match { - case specificType: SpecificRow => - ..$columnChecks - return true - case other => super.equals(other) - } - """ + class SpecificProjection extends ${classOf[BaseProject].getName} { + private $exprType[] expressions = null; - val allColumns = (0 until expressions.size).map { i => - val iLit = ru.Literal(Constant(i)) - q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + public SpecificProjection($exprType[] expr) { + expressions = expr; + } + + @Override + public Object apply(Object r) { + return new SpecificRow(expressions, (InternalRow) r); + } } - val copyFunction = - q"override def copy() = new $genericRowType(Array[Any](..$allColumns))" - - val toSeqFunction = - q"override def toSeq: Seq[Any] = Seq(..$allColumns)" - - val classBody = - nullFunctions ++ ( - lengthDef +: - applyFunction +: - updateFunction +: - equalsFunction +: - hashCodeFunction +: - copyFunction +: - toSeqFunction +: - (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) - - val code = q""" - final class SpecificRow(i: $rowType) extends $mutableRowType { - ..$classBody + final class SpecificRow extends ${classOf[MutableRow].getName} { + + $columns + + public SpecificRow($exprType[] expressions, InternalRow i) { + $initColumns + } + + public int length() { return ${expressions.length};} + protected boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } + + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; + } + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } + } + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + @Override + public boolean equals(Object other) { + if (other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; + $columnChecks + return true; + } + return super.equals(other); + } + + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${classOf[GenericInternalRow].getName}(arr); + } + } """ - log.debug( - s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") - toolBox.eval(code).asInstanceOf[Projection] + logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") + + compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 528e38a50a74..7f1b12cdd580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,12 +27,6 @@ import org.apache.spark.util.Utils */ package object codegen { - /** - * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala - * 2.10. - */ - protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock - /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala new file mode 100644 index 000000000000..fa70409353e7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Returns an Array containing the evaluation of all children expressions. + */ +case class CreateArray(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") + + override def dataType: DataType = { + ArrayType( + children.headOption.map(_.dataType).getOrElse(NullType), + containsNull = children.exists(_.nullable)) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + children.map(_.eval(input)) + } + + override def prettyName: String = "array" +} + +/** + * Returns a Row containing the evaluation of all children expressions. + * TODO: [[CreateStruct]] does not support codegen. + */ +case class CreateStruct(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } + + override def prettyName: String = "struct" +} + +/** + * Creates a struct with the given field names and values + * + * @param children Seq(name1, val1, name2, val2, ...) + */ +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size % 2 != 0) { + TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") + } else { + val invalidNames = + nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + if (invalidNames.size != 0) { + TypeCheckResult.TypeCheckFailure( + s"Odd position only allow foldable and not-null StringType expressions, got :" + + s" ${invalidNames.mkString(",")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + } + + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala deleted file mode 100644 index 956a2429b0b6..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.types._ - - -/** - * Returns an Array containing the evaluation of all children expressions. - */ -case class CreateArray(children: Seq[Expression]) extends Expression { - override type EvaluatedType = Any - - override def foldable: Boolean = children.forall(_.foldable) - - lazy val childTypes = children.map(_.dataType).distinct - - override lazy val resolved = - childrenResolved && childTypes.size <= 1 - - override def dataType: DataType = { - assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}") - ArrayType( - childTypes.headOption.getOrElse(NullType), - containsNull = children.exists(_.nullable)) - } - - override def nullable: Boolean = false - - override def eval(input: Row): Any = { - children.map(_.eval(input)) - } - - override def toString: String = s"Array(${children.mkString(",")})" -} - -/** - * Returns a Row containing the evaluation of all children expressions. - * TODO: [[CreateStruct]] does not support codegen. - */ -case class CreateStruct(children: Seq[NamedExpression]) extends Expression { - override type EvaluatedType = Row - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - assert(resolved, - s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") - val fields = children.map { child => - StructField(child.name, child.dataType, child.nullable, child.metadata) - } - StructType(fields) - } - - override def nullable: Boolean = false - - override def eval(input: Row): EvaluatedType = { - Row(children.map(_.eval(input)): _*) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala new file mode 100644 index 000000000000..1d7393d3d91f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{BooleanType, DataType} + + +case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) + extends Expression { + + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable + + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.TypeCheckFailure( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def dataType: DataType = trueValue.dataType + + override def eval(input: InternalRow): Any = { + if (true == predicate.eval(input)) { + trueValue.eval(input) + } else { + falseValue.eval(input) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; + } + """ + } + + override def toString: String = s"if ($predicate) $trueValue else $falseValue" +} + +trait CaseWhenLike extends Expression { + self: Product => + + // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + // element is the value for the default catch-all case (if provided). + // Hence, `branches` consists of at least two elements, and can have an odd or even length. + def branches: Seq[Expression] + + @transient lazy val whenList = + branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq + @transient lazy val thenList = + branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq + val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) + + // both then and else expressions should be considered. + def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) + def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 + + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + checkTypesInternal() + } else { + TypeCheckResult.TypeCheckFailure( + "THEN and ELSE expressions should all be same type or coercible to a common type") + } + } + + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType + + override def nullable: Boolean = { + // If no value is nullable and no elseValue is provided, the whole statement defaults to null. + thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: InternalRow): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + return res + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + + override def toString: String = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray + + override def children: Seq[Expression] = key +: branches + + override protected def checkTypesInternal(): TypeCheckResult = { + if ((key +: whenList).map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + "key and WHEN expressions should all be same type or coercible to a common type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: InternalRow): Any = { + val evaluatedKey = key.eval(input) + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + while (i < len - 1) { + if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + return res + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val keyEval = key.gen(ctx) + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (${keyEval.isNull} && ${cond.isNull} || + !${keyEval.isNull} && !${cond.isNull} + && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${keyEval.code} + $cases + $other + """ + } + + private def equalNullSafe(l: Any, r: Any) = { + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false + } else { + l == r + } + } + + override def toString: String = { + s"CASE $key" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala new file mode 100644 index 000000000000..13ba2f2e5d62 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * Returns the current date at the start of query evaluation. + * All calls of current_date within the same query return the same value. + */ +case class CurrentDate() extends LeafExpression { + override def foldable: Boolean = true + override def nullable: Boolean = false + + override def dataType: DataType = DateType + + override def eval(input: InternalRow): Any = { + DateTimeUtils.millisToDays(System.currentTimeMillis()) + } +} + +/** + * Returns the current timestamp at the start of query evaluation. + * All calls of current_timestamp within the same query return the same value. + */ +case class CurrentTimestamp() extends LeafExpression { + override def foldable: Boolean = true + override def nullable: Boolean = false + + override def dataType: DataType = TimestampType + + override def eval(input: InternalRow): Any = { + System.currentTimeMillis() * 10000L + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index adb94df7d1c7..f5c2dde191cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ -/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class UnscaledValue(child: Expression) extends UnaryExpression { - override type EvaluatedType = Any override def dataType: DataType = LongType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"UnscaledValue($child)" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val childResult = child.eval(input) if (childResult == null) { null @@ -36,18 +38,23 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { childResult.asInstanceOf[Decimal].toUnscaledLong } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + } } -/** Create a Decimal from an unscaled Long value */ +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { - override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"MakeDecimal($child,$precision,$scale)" - override def eval(input: Row): Decimal = { + override def eval(input: InternalRow): Decimal = { val childResult = child.eval(input) if (childResult == null) { null @@ -55,4 +62,18 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.decimalType} ${ev.primitive} = null; + + if (!${ev.isNull}) { + ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull( + ${eval.primitive}, $precision, $scale); + ${ev.isNull} = ${ev.primitive} == null; + } + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 747a47bdde95..7a42a1d31058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ @@ -40,8 +42,6 @@ import org.apache.spark.sql.types._ abstract class Generator extends Expression { self: Product => - override type EvaluatedType = TraversableOnce[Row] - // TODO ideally we should return the type of ArrayType(StructType), // however, we don't keep the output field names in the Generator. override def dataType: DataType = throw new UnsupportedOperationException @@ -55,13 +55,13 @@ abstract class Generator extends Expression { def elementTypes: Seq[(DataType, Boolean)] /** Should be implemented by child classes to perform specific Generators. */ - override def eval(input: Row): TraversableOnce[Row] + override def eval(input: InternalRow): TraversableOnce[InternalRow] /** * Notifies that there are no more rows to process, clean up code, and additional * rows can be made here. */ - def terminate(): TraversableOnce[Row] = Nil + def terminate(): TraversableOnce[InternalRow] = Nil } /** @@ -69,16 +69,27 @@ abstract class Generator extends Expression { */ case class UserDefinedGenerator( elementTypes: Seq[(DataType, Boolean)], - function: Row => TraversableOnce[Row], + function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator { - override def eval(input: Row): TraversableOnce[Row] = { - // TODO(davies): improve this + @transient private[this] var inputRow: InterpretedProjection = _ + @transient private[this] var convertToScala: (InternalRow) => Row = _ + + private def initializeConverters(): Unit = { + inputRow = new InterpretedProjection(children) + convertToScala = { + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + CatalystTypeConverters.createToScalaConverter(inputSchema) + }.asInstanceOf[InternalRow => Row] + } + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + if (inputRow == null) { + initializeConverters() + } // Convert the objects into Scala Type before calling function, we need schema to support UDT - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - val inputRow = new InterpretedProjection(children) - function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) + function(convertToScala(inputRow(input))) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" @@ -90,23 +101,29 @@ case class UserDefinedGenerator( case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { - override lazy val resolved = - child.resolved && - (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"input to function explode should be array or map type, not ${child.dataType}") + } + } override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } - override def eval(input: Row): TraversableOnce[Row] = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] - if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) + if (inputArray == null) Nil else inputArray.map(v => InternalRow(v)) case MapType(_, _, _) => - val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] - if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } + val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] + if (inputMap == null) Nil + else inputMap.map { case (k, v) => InternalRow(k, v) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 5f8c7354aede..479224af5627 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,9 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String object Literal { def apply(v: Any): Literal = v match { @@ -31,13 +34,13 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(UTF8String(s), StringType) + case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case t: Timestamp => Literal(t, TimestampType) - case d: Date => Literal(DateUtils.fromJavaDate(d), DateType) + case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) + case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) case _ => @@ -78,18 +81,71 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" - type EvaluatedType = Any - override def eval(input: Row): Any = value + override def equals(other: Any): Boolean = other match { + case o: Literal => + dataType.equals(o.dataType) && + (value == null && null == o.value || value != null && value.equals(o.value)) + case _ => false + } + + override def eval(input: InternalRow): Any = value + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + // change the isNull and primitive to consts, to inline them + if (value == null) { + ev.isNull = "true" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};" + } else { + dataType match { + case BooleanType => + ev.isNull = "false" + ev.primitive = value.toString + "" + case FloatType => + val v = value.asInstanceOf[Float] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}f" + "" + } + case DoubleType => + val v = value.asInstanceOf[Double] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}" + "" + } + case ByteType | ShortType => + ev.isNull = "false" + ev.primitive = s"(${ctx.javaType(dataType)})$value" + "" + case IntegerType | DateType => + ev.isNull = "false" + ev.primitive = value.toString + "" + case TimestampType | LongType => + ev.isNull = "false" + ev.primitive = s"${value}L" + "" + // eval() version may be faster for non-primitive types + case other => + super.genCode(ctx, ev) + } + } + } } // TODO: Specialize case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { - type EvaluatedType = Any - def update(expression: Expression, input: Row): Unit = { + def update(expression: Expression, input: InternalRow): Unit = { value = expression.eval(input) } - override def eval(input: Row): Any = value + override def eval(input: InternalRow): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala new file mode 100644 index 000000000000..92500453980f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -0,0 +1,609 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.{lang => jl} + +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A leaf expression specifically for math constants. Math constants expect no input. + * @param c The math constant. + * @param name The short name of the function + */ +abstract class LeafMathExpression(c: Double, name: String) + extends LeafExpression with Serializable { + self: Product => + + override def dataType: DataType = DoubleType + override def foldable: Boolean = true + override def nullable: Boolean = false + override def toString: String = s"$name()" + + override def eval(input: InternalRow): Any = c + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name; + """ + } +} + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class UnaryMathExpression(f: Double => Double, name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => + + override def inputTypes: Seq[DataType] = Seq(DoubleType) + override def dataType: DataType = DoubleType + override def nullable: Boolean = true + override def toString: String = s"$name($child)" + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + + // name of function in java.lang.Math + def funcName: String = name.toLowerCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + ${ev.primitive} = java.lang.Math.${funcName}($eval); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + }) + } +} + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + + override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def toString: String = s"$name($left, $right)" + + override def dataType: DataType = DoubleType + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Leaf math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +case class EulerNumber() extends LeafMathExpression(math.E, "E") + +case class Pi() extends LeafMathExpression(math.Pi, "PI") + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Unary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") + +case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") + +case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") + +case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") + +case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") + +case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") + +case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") + +case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") + +case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") + +case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") + +object Factorial { + + def factorial(n: Int): Long = { + if (n < factorials.length) factorials(n) else Long.MaxValue + } + + private val factorials: Array[Long] = Array[Long]( + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800L, + 87178291200L, + 1307674368000L, + 20922789888000L, + 355687428096000L, + 6402373705728000L, + 121645100408832000L, + 2432902008176640000L + ) +} + +case class Factorial(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def dataType: DataType = LongType + + override def foldable: Boolean = child.foldable + + // If the value not in the range of [0, 20], it still will be null, so set it to be true here. + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val input = evalE.asInstanceOf[jl.Integer] + if (input > 20 || input < 0) { + null + } else { + Factorial.factorial(input) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} > 20 || ${eval.primitive} < 0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = + org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive}); + } + } + """ + } +} + +case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") + +case class Log2(child: Expression) + extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } +} + +case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") + +case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") + +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { + override def funcName: String = "rint" +} + +case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") + +case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") + +case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") + +case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") + +case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") + +case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") + +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { + override def funcName: String = "toDegrees" +} + +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { + override def funcName: String = "toRadians" +} + +case class Bin(child: Expression) + extends UnaryExpression with Serializable with ExpectsInputTypes { + + override def inputTypes: Seq[DataType] = Seq(LongType) + override def dataType: DataType = StringType + + override def eval(input: InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c) => + s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))") + } +} + +object Hex { + val hexDigits = Array[Char]( + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + ).map(_.toByte) + + // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 + val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } +} + +/** + * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. + * Otherwise if the number is a STRING, it converts each character into its hex representation + * and returns the resulting STRING. Negative numbers would be treated as two's complement. + */ +case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes { + // TODO: Create code-gen version. + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, BinaryType, StringType)) + + override def dataType: DataType = StringType + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + child.dataType match { + case LongType => hex(num.asInstanceOf[Long]) + case BinaryType => hex(num.asInstanceOf[Array[Byte]]) + case StringType => hex(num.asInstanceOf[UTF8String].getBytes) + } + } + } + + private[this] def hex(bytes: Array[Byte]): UTF8String = { + val length = bytes.length + val value = new Array[Byte](length * 2) + var i = 0 + while (i < length) { + value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) + i += 1 + } + UTF8String.fromBytes(value) + } + + private def hex(num: Long): UTF8String = { + // Extract the hex digits of num into value[] from right to left + val value = new Array[Byte](16) + var numBuf = num + var len = 0 + do { + len += 1 + value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) + numBuf >>>= 4 + } while (numBuf != 0) + UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) + } +} + +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. + */ +case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { + // TODO: Create code-gen version. + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def nullable: Boolean = true + override def dataType: DataType = BinaryType + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + unhex(num.asInstanceOf[UTF8String].getBytes) + } + } + + private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + val out = new Array[Byte]((bytes.length + 1) >> 1) + var i = 0 + if ((bytes.length & 0x01) != 0) { + // padding with '0' + if (bytes(0) < 0) { + return null + } + val v = Hex.unhexDigits(bytes(0)) + if (v == -1) { + return null + } + out(0) = v + i += 1 + } + // two characters form the hex value. + while (i < bytes.length) { + if (bytes(i) < 0 || bytes(i + 1) < 0) { + return null + } + val first = Hex.unhexDigits(bytes(i)) + val second = Hex.unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { + return null + } + out(i / 2) = (((first << 4) | second) & 0xFF).toByte + i += 2 + } + out + } +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Binary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + + +/** + * Bitwise unsigned left shift. + * @param left the base number to shift. + * @param right number of bits to left shift. + */ +case class ShiftLeft(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer] + } + } else { + null + } + } else { + null + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;") + } +} + + +/** + * Bitwise unsigned left shift. + * @param left the base number to shift. + * @param right number of bits to left shift. + */ +case class ShiftRight(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer] + } + } else { + null + } + } else { + null + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;") + } +} + + +/** + * Bitwise unsigned right shift, for integer and long data type. + * @param left the base number. + * @param right the number of bits to right shift. + */ +case class ShiftRightUnsigned(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(IntegerType, LongType), IntegerType) + + override def dataType: DataType = left.dataType + + override def eval(input: InternalRow): Any = { + val valueLeft = left.eval(input) + if (valueLeft != null) { + val valueRight = right.eval(input) + if (valueRight != null) { + valueLeft match { + case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer] + case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer] + } + } else { + null + } + } else { + null + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;") + } +} + + +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") + + +/** + * Computes the logarithm of a number. + * @param left the logarithm base, default to e. + * @param right the number to compute the logarithm of. + */ +case class Logarithm(left: Expression, right: Expression) + extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { + + /** + * Natural log, i.e. using e as the base. + */ + def this(child: Expression) = { + this(EulerNumber(), child) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val logCode = if (left.isInstanceOf[EulerNumber]) { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + } + logCode + s""" + if (Double.isNaN(${ev.primitive})) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala deleted file mode 100644 index fcc06d3aa103..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.mathfuncs - -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} -import org.apache.spark.sql.types._ - -/** - * A binary expression specifically for math functions that take two `Double`s as input and returns - * a `Double`. - * @param f The math function. - * @param name The short name of the function - */ -abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - type EvaluatedType = Any - override def symbol: String = null - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) - - override def nullable: Boolean = left.nullable || right.nullable - override def toString: String = s"$name($left, $right)" - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) - if (result.isNaN) null else result - } - } - } -} - -case class Atan2( - left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, - evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result - } - } - } -} - -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") - -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala deleted file mode 100644 index dc68469e060c..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.mathfuncs - -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} -import org.apache.spark.sql.types._ - -/** - * A unary expression specifically for math functions. Math Functions expect a specific type of - * input format, therefore these functions extend `ExpectsInputTypes`. - * @param name The short name of the function - */ -abstract class MathematicalExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { - self: Product => - type EvaluatedType = Any - - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) - override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = true - override def toString: String = s"$name($child)" - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val result = f(evalE.asInstanceOf[Double]) - if (result.isNaN) null else result - } - } -} - -case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS") - -case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN") - -case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN") - -case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT") - -case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL") - -case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS") - -case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH") - -case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP") - -case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1") - -case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR") - -case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG") - -case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10") - -case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P") - -case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND") - -case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM") - -case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN") - -case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH") - -case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN") - -case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH") - -case class ToDegrees(child: Expression) - extends MathematicalExpression(math.toDegrees, "DEGREES") - -case class ToRadians(child: Expression) - extends MathematicalExpression(math.toRadians, "RADIANS") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala new file mode 100644 index 000000000000..e008af396694 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.security.MessageDigest +import java.security.NoSuchAlgorithmException +import java.util.zip.CRC32 + +import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A function that calculates an MD5 128-bit checksum and returns it as a hex string + * For input of type [[BinaryType]] + */ +case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + } +} + +/** + * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512) + * and returns it as a hex string. The first argument is the string or binary to be hashed. The + * second argument indicates the desired bit length of the result, which must have a value of 224, + * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If + * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or + * the hash length is not one of the permitted values, the return value is NULL. + */ +case class Sha2(left: Expression, right: Expression) + extends BinaryExpression with Serializable with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val bitLength = evalE2.asInstanceOf[Int] + val input = evalE1.asInstanceOf[Array[Byte]] + bitLength match { + case 224 => + // DigestUtils doesn't support SHA-224 now + try { + val md = MessageDigest.getInstance("SHA-224") + md.update(input) + UTF8String.fromBytes(md.digest()) + } catch { + // SHA-224 is not supported on the system, return null + case noa: NoSuchAlgorithmException => null + } + case 256 | 0 => + UTF8String.fromString(DigestUtils.sha256Hex(input)) + case 384 => + UTF8String.fromString(DigestUtils.sha384Hex(input)) + case 512 => + UTF8String.fromString(DigestUtils.sha512Hex(input)) + case _ => null + } + } + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val digestUtils = "org.apache.commons.codec.digest.DigestUtils" + + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + if (${eval2.primitive} == 224) { + try { + java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); + md.update(${eval1.primitive}); + ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + } catch (java.security.NoSuchAlgorithmException e) { + ${ev.isNull} = true; + } + } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 384) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 512) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive})); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } +} + +/** + * A function that calculates a sha1 hash value and returns it as a hex string + * For input of type [[BinaryType]] or [[StringType]] + */ +case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + ) + } +} + +/** + * A function that computes a cyclic redundancy check value and returns it as a bigint + * For input of type [[BinaryType]] + */ +case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = LongType + + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + val checksum = new CRC32 + checksum.update(value.asInstanceOf[Array[Byte]], 0, value.asInstanceOf[Array[Byte]].length) + checksum.getValue + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val value = child.gen(ctx) + val CRC32 = "java.util.zip.CRC32" + s""" + ${value.code} + boolean ${ev.isNull} = ${value.isNull}; + long ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${CRC32} checksum = new ${CRC32}(); + checksum.update(${value.primitive}, 0, ${value.primitive}.length); + ${ev.primitive} = checksum.getValue(); + } + """ + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a9170589f8c6..81ebda3060c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ object NamedExpression { @@ -111,11 +111,13 @@ case class Alias(child: Expression, name: String)( val explicitMetadata: Option[Metadata] = None) extends NamedExpression with trees.UnaryNode[Expression] { - override type EvaluatedType = Any // Alias(Generator, xx) need to be transformed into Generate(generator, ...) - override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator] + + override def eval(input: InternalRow): Any = child.eval(input) - override def eval(input: Row): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable @@ -181,6 +183,11 @@ case class AttributeReference( case _ => false } + override def semanticEquals(other: Expression): Boolean = other match { + case ar: AttributeReference => sameRef(ar) + case _ => false + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 @@ -224,7 +231,7 @@ case class AttributeReference( } // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$name#${exprId.id}$typeSuffix" @@ -235,7 +242,6 @@ case class AttributeReference( * expression id or the unresolved indicator. */ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - type EvaluatedType = Any override def toString: String = name @@ -247,12 +253,12 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } object VirtualColumn { val groupingIdName: String = "grouping__id" - def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index f9161cf34f0c..145d323a9f0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,34 +17,30 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { - type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - override def nullable: Boolean = !children.exists(!_.nullable) + override def nullable: Boolean = children.forall(_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) - // Only resolved if all the children are of the same type. - override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) - - override def toString: String = s"Coalesce(${children.mkString(",")})" - - override def dataType: DataType = if (resolved) { - children.head.dataType - } else { - val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") - throw new UnresolvedException( - this, s"Coalesce cannot have children of different types. $childTypes") + override def checkInputDataTypes(): TypeCheckResult = { + if (children == Nil) { + TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + } } - override def eval(input: Row): Any = { - var i = 0 + override def dataType: DataType = children.head.dataType + + override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator while (childIterator.hasNext && result == null) { @@ -52,27 +48,58 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } result } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + + children.map { e => + val eval = e.gen(ctx) + s""" + if (${ev.isNull}) { + ${eval.code} + if (!${eval.isNull}) { + ${ev.isNull} = false; + ${ev.primitive} = ${eval.primitive}; + } + } + """ + }.mkString("\n") + } } -case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable: Boolean = child.foldable +case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { child.eval(input) == null } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = eval.isNull + eval.code + } + override def toString: String = s"IS NULL $child" } -case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable: Boolean = child.foldable +case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false override def toString: String = s"IS NOT NULL $child" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { child.eval(input) != null } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = s"(!(${eval.isNull}))" + eval.code + } } /** @@ -85,7 +112,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate private[this] val childrenArray = children.toArray - override def eval(input: Row): Boolean = { + override def eval(input: InternalRow): Boolean = { var numNonNulls = 0 var i = 0 while (i < childrenArray.length && numNonNulls < n) { @@ -96,4 +123,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } numNonNulls >= n } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val nonnull = ctx.freshName("nonnull") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull}) { + $nonnull += 1; + } + } + """ + }.mkString("\n") + s""" + int $nonnull = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $nonnull >= $n; + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index fbc97b2e7531..d24d74e7b82a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -49,30 +49,30 @@ package org.apache.spark.sql.catalyst */ package object expressions { - type Row = org.apache.spark.sql.Row + type InternalRow = org.apache.spark.sql.catalyst.InternalRow - val Row = org.apache.spark.sql.Row + val InternalRow = org.apache.spark.sql.catalyst.InternalRow /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound - * to that schema. + * Converts a [[InternalRow]] to another Row given a sequence of expression that define each + * column of the new row. If the schema of the input row is specified, then the given expression + * will be bound to that schema. */ - abstract class Projection extends (Row => Row) + abstract class Projection extends (InternalRow => InternalRow) /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound - * to that schema. + * Converts a [[InternalRow]] to another Row given a sequence of expression that define each + * column of the new row. If the schema of the input row is specified, then the given expression + * will be bound to that schema. * * In contrast to a normal projection, a MutableProjection reuses the same underlying row object * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` - * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` - * and hold on to the returned [[Row]] before calling `next()`. + * projection, but means that it is not safe to hold on to a reference to a [[InternalRow]] after + * `next()` has been called on the [[Iterator]] that produced it. Instead, the user must call + * `InternalRow.copy()` and hold on to the returned [[InternalRow]] before calling `next()`. */ abstract class MutableProjection extends Projection { - def currentValue: Row + def currentValue: InternalRow /** Uses the given row to store the output of the projection. */ def target(row: MutableRow): MutableProjection diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1d72a9eb834b..0b479f466c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} +import org.apache.spark.sql.types._ object InterpretedPredicate { - def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) = create(BindReferences.bindReference(expression, inputSchema)) - def create(expression: Expression): (Row => Boolean) = { - (r: Row) => expression.eval(r).asInstanceOf[Boolean] + def create(expression: Expression): (InternalRow => Boolean) = { + (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } @@ -35,8 +36,6 @@ trait Predicate extends Expression { self: Product => override def dataType: DataType = BooleanType - - type EvaluatedType = Any } trait PredicateHelper { @@ -70,20 +69,21 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { child.eval(input) match { case null => null case b: Boolean => !b } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"!($c)") + } } /** @@ -95,7 +95,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) list.exists(e => e.eval(input) == evaluatedValue) } @@ -114,7 +114,7 @@ case class InSet(value: Expression, hset: Set[Any]) override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { hset.contains(value.eval(input)) } } @@ -122,11 +122,11 @@ case class InSet(value: Expression, hset: Set[Any]) case class And(left: Expression, right: Expression) extends BinaryExpression with Predicate with ExpectsInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def toString: String = s"($left && $right)" - override def symbol: String = "&&" + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val l = left.eval(input) if (l == false) { false @@ -143,16 +143,39 @@ case class And(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + // The result should be `false`, if any of them is `false` whenever the other is null or not. + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + + if (!${eval1.isNull} && !${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && !${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.isNull} = true; + } + } + """ + } } case class Or(left: Expression, right: Expression) extends BinaryExpression with Predicate with ExpectsInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def toString: String = s"($left || $right)" - override def symbol: String = "||" + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val l = left.eval(input) if (l == true) { true @@ -169,61 +192,47 @@ case class Or(left: Expression, right: Expression) } } } -} -abstract class BinaryComparison extends BinaryExpression with Predicate { - self: Product => -} + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" + // The result should be `true`, if any of them is `true` whenever the other is null or not. + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = true; - override def eval(input: Row): Any = { - val l = left.eval(input) - if (l == null) { - null - } else { - val r = right.eval(input) - if (r == null) null - else if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) - } + if (!${eval1.isNull} && ${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && ${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = false; + } else { + ${ev.isNull} = true; + } + } + """ } } -case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=>" - - override def nullable: Boolean = false +abstract class BinaryComparison extends BinaryOperator with Predicate { + self: Product => - override def eval(input: Row): Any = { - val l = left.eval(input) - val r = right.eval(input) - if (l == null && r == null) { - true - } else if (l == null || r == null) { - false + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") } else { - l == r + checkTypesInternal(dataType) } } -} - -case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + protected def checkTypesInternal(t: DataType): TypeCheckResult - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { null @@ -232,258 +241,124 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso if (evalE2 == null) { null } else { - ordering.lt(evalE1, evalE2) + evalInternal(evalE1, evalE2) } } } -} -case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=" - - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } - - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + if (ctx.isPrimitiveType(left.dataType)) { + // faster version + defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lteq(evalE1, evalE2) - } + defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } -} -case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">" + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryComparisons must override either eval or evalInternal") +} - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } +private[sql] object BinaryComparison { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) +} - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gt(evalE1, evalE2) - } - } +/** An extractor that matches both standard 3VL equality and null-safe equality. */ +private[sql] object Equality { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { + case EqualTo(l, r) => Some((l, r)) + case EqualNullSafe(l, r) => Some((l, r)) + case _ => None } } -case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">=" +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + + protected override def evalInternal(l: Any, r: Any) = { + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gteq(evalE1, evalE2) - } - } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } -case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { - - override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil - override def nullable: Boolean = trueValue.nullable || falseValue.nullable +case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "<=>" - override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException( - this, - s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}") - } - trueValue.dataType - } + override def nullable: Boolean = false - type EvaluatedType = Any + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - override def eval(input: Row): Any = { - if (true == predicate.eval(input)) { - trueValue.eval(input) + override def eval(input: InternalRow): Any = { + val l = left.eval(input) + val r = right.eval(input) + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false } else { - falseValue.eval(input) + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } - override def toString: String = s"if ($predicate) $trueValue else $falseValue" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive) + ev.isNull = "false" + eval1.code + eval2.code + s""" + boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || + (!${eval1.isNull} && $equalCode); + """ + } } -trait CaseWhenLike extends Expression { - self: Product => +case class LessThan(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "<" - type EvaluatedType = Any + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last - // element is the value for the default catch-all case (if provided). - // Hence, `branches` consists of at least two elements, and can have an odd or even length. - def branches: Seq[Expression] + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) +} - @transient lazy val whenList = - branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq - @transient lazy val thenList = - branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq - val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) +case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "<=" - // both then and else val should be considered. - def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1 + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") - } - valueTypes.head - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) - override def nullable: Boolean = { - // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) - } + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) } -// scalastyle:off -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = branches - - override lazy val resolved: Boolean = - childrenResolved && - whenList.forall(_.dataType == BooleanType) && - valueTypesEqual - - /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (branchesArr(i).eval(input) == true) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) - } - return res - } +case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = ">" - override def toString: String = { - "CASE" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) } -// scalastyle:off -/** - * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". - * Refer to this link for the corresponding semantics: - * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - */ -// scalastyle:on -case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { - - // Use private[this] Array to speed up evaluation. - @transient private[this] lazy val branchesArr = branches.toArray - - override def children: Seq[Expression] = key +: branches - - override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual - - /** Written in imperative fashion for performance considerations. */ - override def eval(input: Row): Any = { - val evaluatedKey = key.eval(input) - val len = branchesArr.length - var i = 0 - // If all branches fail and an elseVal is not provided, the whole statement - // defaults to null, according to Hive's semantics. - while (i < len - 1) { - if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { - return branchesArr(i + 1).eval(input) - } - i += 2 - } - var res: Any = null - if (i == len - 1) { - res = branchesArr(i).eval(input) - } - return res - } +case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = ">=" - private def equalNullSafe(l: Any, r: Any) = { - if (l == null && r == null) { - true - } else if (l == null || r == null) { - false - } else { - l == r - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def toString: String = { - s"CASE $key" + branches.sliding(2, 2).map { - case Seq(cond, value) => s" WHEN $cond THEN $value" - case Seq(elseValue) => s" ELSE $elseValue" - }.mkString - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 66d7c8b07cce..45588bacd2e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** * A Random distribution generating expression. - * TODO: This can be made generic to generate any type of random distribution, or any type of + * TODO: This can be made generic to generate any type of random distribution, or any type of * StructType. * * Since this expression is stateful, it cannot be a case object. @@ -36,9 +37,13 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize it. */ - @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) + @transient protected lazy val partitionId = TaskContext.get() match { + case null => 0 + case _ => TaskContext.get().partitionId() + } + @transient protected lazy val rng = new XORShiftRandom(seed + partitionId) - override type EvaluatedType = Double + override def deterministic: Boolean = false override def nullable: Boolean = false @@ -46,11 +51,25 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ -case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) { - override def eval(input: Row): Double = rng.nextDouble() +case class Rand(seed: Long) extends RDG(seed) { + override def eval(input: InternalRow): Double = rng.nextDouble() + + def this() = this(Utils.random.nextLong()) + + def this(seed: Expression) = this(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) } /** Generate a random column with i.i.d. gaussian random distribution. */ -case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) { - override def eval(input: Row): Double = rng.nextGaussian() +case class Randn(seed: Long) extends RDG(seed) { + override def eval(input: InternalRow): Double = rng.nextGaussian() + + def this() = this(Utils.random.nextLong()) + + def this(seed: Expression) = this(seed match { + case IntegerLiteral(s) => s + case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") + }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 5fd892c42e69..dd5f2ed2d382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,32 +17,46 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DataType, StructType, AtomicType} +import org.apache.spark.unsafe.types.UTF8String /** - * An extended interface to [[Row]] that allows the values for each column to be updated. Setting - * a value through a primitive function implicitly marks that column as not null. + * An extended interface to [[InternalRow]] that allows the values for each column to be updated. + * Setting a value through a primitive function implicitly marks that column as not null. */ -trait MutableRow extends Row { +abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(ordinal: Int, value: Any) - - def setInt(ordinal: Int, value: Int) - def setLong(ordinal: Int, value: Long) - def setDouble(ordinal: Int, value: Double) - def setBoolean(ordinal: Int, value: Boolean) - def setShort(ordinal: Int, value: Short) - def setByte(ordinal: Int, value: Byte) - def setFloat(ordinal: Int, value: Float) - def setString(ordinal: Int, value: String) - // TODO(davies): add setDate() and setDecimal() + def update(i: Int, value: Any) + + // default implementation (slow) + def setInt(i: Int, value: Int): Unit = { update(i, value) } + def setLong(i: Int, value: Long): Unit = { update(i, value) } + def setDouble(i: Int, value: Double): Unit = { update(i, value) } + def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } + def setShort(i: Int, value: Short): Unit = { update(i, value) } + def setByte(i: Int, value: Byte): Unit = { update(i, value) } + def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setString(i: Int, value: String): Unit = { + update(i, UTF8String.fromString(value)) + } + + override def copy(): InternalRow = { + val arr = new Array[Any](length) + var i = 0 + while (i < length) { + arr(i) = get(i) + i += 1 + } + new GenericInternalRow(arr) + } } /** * A row with no data. Calling any methods will result in an error. Can be used as a placeholder. */ -object EmptyRow extends Row { +object EmptyRow extends InternalRow { override def apply(i: Int): Any = throw new UnsupportedOperationException override def toSeq: Seq[Any] = Seq.empty override def length: Int = 0 @@ -56,120 +70,57 @@ object EmptyRow extends Row { override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - override def copy(): Row = this + override def copy(): InternalRow = this } /** - * A row implementation that uses an array of objects as the underlying storage. Note that, while - * the array is not copied, and thus could technically be mutated after creation, this is not - * allowed. + * A row implementation that uses an array of objects as the underlying storage. */ -class GenericRow(protected[sql] val values: Array[Any]) extends Row { - /** No-arg constructor for serialization. */ - protected def this() = this(null) +trait ArrayBackedRow { + self: Row => - def this(size: Int) = this(new Array[Any](size)) + protected val values: Array[Any] override def toSeq: Seq[Any] = values.toSeq - override def length: Int = values.length + def length: Int = values.length override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int): Boolean = values(i) == null - - override def getInt(i: Int): Int = { - if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") - values(i).asInstanceOf[Int] - } - - override def getLong(i: Int): Long = { - if (values(i) == null) sys.error("Failed to check null bit for primitive long value.") - values(i).asInstanceOf[Long] - } - - override def getDouble(i: Int): Double = { - if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") - values(i).asInstanceOf[Double] - } - - override def getFloat(i: Int): Float = { - if (values(i) == null) sys.error("Failed to check null bit for primitive float value.") - values(i).asInstanceOf[Float] - } + def setNullAt(i: Int): Unit = { values(i) = null} - override def getBoolean(i: Int): Boolean = { - if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") - values(i).asInstanceOf[Boolean] - } - - override def getShort(i: Int): Short = { - if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") - values(i).asInstanceOf[Short] - } - - override def getByte(i: Int): Byte = { - if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") - values(i).asInstanceOf[Byte] - } - - override def getString(i: Int): String = { - values(i) match { - case null => null - case s: String => s - case utf8: UTF8String => utf8.toString - } - } - - // TODO(davies): add getDate and getDecimal + def update(i: Int, value: Any): Unit = { values(i) = value } +} - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 +/** + * A row implementation that uses an array of objects as the underlying storage. Note that, while + * the array is not copied, and thus could technically be mutated after creation, this is not + * allowed. + */ +class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow { + /** No-arg constructor for serialization. */ + protected def this() = this(null) - var i = 0 - while (i < values.length) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - apply(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } + def this(size: Int) = this(new Array[Any](size)) + // This is used by test or outside override def equals(o: Any): Boolean = o match { - case other: Row => - if (values.length != other.length) { - return false - } - + case other: Row if other.length == length => var i = 0 - while (i < values.length) { + while (i < length) { if (isNullAt(i) != other.isNullAt(i)) { return false } - if (apply(i) != other.apply(i)) { + val equal = (apply(i), other.apply(i)) match { + case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b) + case (a, b) => a == b + } + if (!equal) { return false } i += 1 } true - case _ => false } @@ -185,34 +136,35 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) override def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { +/** + * A internal row implementation that uses an array of objects as the underlying storage. + * Note that, while the array is not copied, and thus could technically be mutated after creation, + * this is not allowed. + */ +class GenericInternalRow(protected[sql] val values: Array[Any]) + extends InternalRow with ArrayBackedRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } - override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } - override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } - override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} - override def setNullAt(i: Int): Unit = { values(i) = null } + override def copy(): InternalRow = this +} - override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } +class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { + /** No-arg constructor for serialization. */ + protected def this() = this(null) - override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } + def this(size: Int) = this(new Array[Any](size)) - override def copy(): Row = new GenericRow(values.clone()) + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } - -class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { +class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) - def compare(a: Row, b: Row): Int = { + def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 while (i < ordering.size) { val order = ordering(i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 4c4418227820..5d51a4ca6533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -51,33 +52,44 @@ private[sql] class OpenHashSetUDT( * Creates a new set of the specified type */ case class NewSet(elementType: DataType) extends LeafExpression { - type EvaluatedType = Any override def nullable: Boolean = false override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { new OpenHashSet[Any]() } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + elementType match { + case IntegerType | LongType => + ev.isNull = "false" + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}(); + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"new Set($dataType)" } /** * Adds an item to a set. * For performance, this expression mutates its input during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { - type EvaluatedType = Any override def children: Seq[Expression] = item :: set :: Nil override def nullable: Boolean = set.nullable - override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = set.dataType - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -93,23 +105,39 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = "false" + ev.primitive = setEval.primitive + itemEval.code + setEval.code + s""" + if (!${itemEval.isNull} && !${setEval.isNull}) { + (($htype)${setEval.primitive}).add(${itemEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"$set += $item" } /** * Combines the elements of two sets. * For performance, this expression mutates its left input set during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - type EvaluatedType = Any - - override def nullable: Boolean = left.nullable || right.nullable - override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = left.dataType - override def symbol: String = "++=" - - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -119,27 +147,43 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightValue = iterator.next() leftEval.add(rightValue) } - leftEval - } else { - null } + leftEval } else { null } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = leftEval.isNull + ev.primitive = leftEval.primitive + leftEval.code + rightEval.code + s""" + if (!${leftEval.isNull} && !${rightEval.isNull}) { + ${leftEval.primitive}.union((${htype})${rightEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } } /** * Returns the number of elements in the input set. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CountSet(child: Expression) extends UnaryExpression { - type EvaluatedType = Any - - override def nullable: Boolean = child.nullable override def dataType: DataType = LongType - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] if (childEval != null) { childEval.size.toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 7683e0990ce8..1a14a7a44934 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,20 +19,21 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern +import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String trait StringRegexExpression extends ExpectsInputTypes { self: BinaryExpression => - type EvaluatedType = Any - def escape(v: String): String def matches(regex: Pattern, str: String): Boolean override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -40,16 +41,16 @@ trait StringRegexExpression extends ExpectsInputTypes { case _ => null } - protected def compile(str: String): Pattern = if(str == null) { + protected def compile(str: String): Pattern = if (str == null) { null } else { // Let it raise exception if couldn't compile the regex string Pattern.compile(escape(str)) } - protected def pattern(str: String) = if(cache == null) compile(str) else cache + protected def pattern(str: String) = if (cache == null) compile(str) else cache - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val l = left.eval(input) if (l == null) { null @@ -75,8 +76,6 @@ trait StringRegexExpression extends ExpectsInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "LIKE" - // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -101,29 +100,27 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends ExpectsInputTypes { self: UnaryExpression => - type EvaluatedType = Any - def convert(v: UTF8String): UTF8String - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) if (evaluated == null) { null @@ -137,20 +134,24 @@ trait CaseConversionExpression extends ExpectsInputTypes { * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - - override def convert(v: UTF8String): UTF8String = v.toUpperCase() - override def toString: String = s"Upper($child)" + override def convert(v: UTF8String): UTF8String = v.toUpperCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") + } } /** * A function that converts the characters of a string to lowercase. */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - - override def convert(v: UTF8String): UTF8String = v.toLowerCase() - override def toString: String = s"Lower($child)" + override def convert(v: UTF8String): UTF8String = v.toLowerCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") + } } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -159,13 +160,11 @@ trait StringComparison extends ExpectsInputTypes { def compare(l: UTF8String, r: UTF8String): Boolean - override type EvaluatedType = Any - override def nullable: Boolean = left.nullable || right.nullable - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val leftEval = left.eval(input) if(leftEval == null) { null @@ -176,8 +175,6 @@ trait StringComparison extends ExpectsInputTypes { } } - override def symbol: String = nodeName - override def toString: String = s"$nodeName($left, $right)" } @@ -187,6 +184,9 @@ trait StringComparison extends ExpectsInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") + } } /** @@ -195,6 +195,9 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") + } } /** @@ -203,6 +206,9 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") + } } /** @@ -211,12 +217,15 @@ case class EndsWith(left: Expression, right: Expression) */ case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression with ExpectsInputTypes { - - type EvaluatedType = Any + + def this(str: Expression, pos: Expression) = { + this(str, pos, Literal(Integer.MAX_VALUE)) + } override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") @@ -224,14 +233,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) if (str.dataType == BinaryType) str.dataType else StringType } - override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil @inline def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it + // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. @@ -250,7 +259,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) (start, end) } - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { val string = str.eval(input) val po = pos.eval(input) val ln = len.eval(input) @@ -266,14 +275,176 @@ case class Substring(str: Expression, pos: Expression, len: Expression) ba.slice(st, end) case s: UTF8String => val (st, end) = slicePos(start, length, () => s.length()) - s.slice(st, end) + s.substring(st, end) } } } +} + +/** + * A function that return the length of the given string expression. + */ +case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: InternalRow): Any = { + val string = child.eval(input) + if (string == null) null else string.asInstanceOf[UTF8String].length + } - override def toString: String = len match { - // TODO: This is broken because max is not an integer value. - case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" - case _ => s"SUBSTR($str, $pos, $len)" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).length()") } + + override def prettyName: String = "length" } + +/** + * A function that return the Levenshtein distance between the two given strings. + */ +case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression + with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override def dataType: DataType = IntegerType + + override def eval(input: InternalRow): Any = { + val leftValue = left.eval(input) + if (leftValue == null) { + null + } else { + val rightValue = right.eval(input) + if(rightValue == null) { + null + } else { + StringUtils.getLevenshteinDistance(leftValue.toString, rightValue.toString) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val stringUtils = classOf[StringUtils].getName + nullSafeCodeGen(ctx, ev, (res, left, right) => + s"$res = $stringUtils.getLevenshteinDistance($left.toString(), $right.toString());") + } +} + +/** + * Returns the numeric value of the first character of str. + */ +case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: InternalRow): Any = { + val string = child.eval(input) + if (string == null) { + null + } else { + val bytes = string.asInstanceOf[UTF8String].getBytes + if (bytes.length > 0) { + bytes(0).asInstanceOf[Int] + } else { + 0 + } + } + } +} + +/** + * Converts the argument from binary to a base 64 string. + */ +case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val bytes = child.eval(input) + if (bytes == null) { + null + } else { + UTF8String.fromBytes( + org.apache.commons.codec.binary.Base64.encodeBase64( + bytes.asInstanceOf[Array[Byte]])) + } + } +} + +/** + * Converts the argument from a base 64 string to BINARY. + */ +case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: InternalRow): Any = { + val string = child.eval(input) + if (string == null) { + null + } else { + org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) + } + } +} + +/** + * Decodes the first argument into a String using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + */ +case class Decode(bin: Expression, charset: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = bin + override def right: Expression = charset + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType) + + override def eval(input: InternalRow): Any = { + val l = bin.eval(input) + if (l == null) { + null + } else { + val r = charset.eval(input) + if (r == null) { + null + } else { + val fromCharset = r.asInstanceOf[UTF8String].toString + UTF8String.fromString(new String(l.asInstanceOf[Array[Byte]], fromCharset)) + } + } + } +} + +/** + * Encodes the first argument into a BINARY using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. +*/ +case class Encode(value: Expression, charset: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = value + override def right: Expression = charset + override def dataType: DataType = BinaryType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def eval(input: InternalRow): Any = { + val l = value.eval(input) + if (l == null) { + null + } else { + val r = charset.eval(input) + if (r == null) { + null + } else { + val toCharset = r.asInstanceOf[UTF8String].toString + l.asInstanceOf[UTF8String].toString.getBytes(toCharset) + } + } + } +} + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 099d67ca7fee..12023ad311dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{NumericType, DataType} +import org.apache.spark.sql.types.{DataType, NumericType} /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -66,17 +65,16 @@ case class WindowSpecDefinition( } } - type EvaluatedType = Any - - override def children: Seq[Expression] = partitionSpec ++ orderSpec + override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = - childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] + childrenResolved && checkInputDataTypes().isSuccess && + frameSpecification.isInstanceOf[SpecifiedWindowFrame] override def toString: String = simpleString - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -261,7 +259,7 @@ trait WindowFunction extends Expression { def reset(): Unit - def prepareInputParameters(input: Row): AnyRef + def prepareInputParameters(input: InternalRow): AnyRef def update(input: AnyRef): Unit @@ -288,7 +286,7 @@ case class UnresolvedWindowFunction( throw new UnresolvedException(this, "init") override def reset(): Unit = throw new UnresolvedException(this, "reset") - override def prepareInputParameters(input: Row): AnyRef = + override def prepareInputParameters(input: InternalRow): AnyRef = throw new UnresolvedException(this, "prepareInputParameters") override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") @@ -299,7 +297,7 @@ case class UnresolvedWindowFunction( override def get(index: Int): Any = throw new UnresolvedException(this, "get") // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -311,25 +309,25 @@ case class UnresolvedWindowFunction( case class UnresolvedWindowExpression( child: UnresolvedWindowFunction, windowSpec: WindowSpecReference) extends UnaryExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } case class WindowExpression( windowFunction: WindowFunction, windowSpec: WindowSpecDefinition) extends Expression { - override type EvaluatedType = Any override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil - override def eval(input: Row): EvaluatedType = + override def eval(input: InternalRow): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def dataType: DataType = windowFunction.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c2818d957cc7..bfd24287c964 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,21 +36,26 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: - Batch("Operator Reordering", FixedPoint(100), - UnionPushdown, - CombineFilters, - PushPredicateThroughProject, + Batch("Distinct", FixedPoint(100), + ReplaceDistinctWithAggregate) :: + Batch("Operator Optimizations", FixedPoint(100), + // Operator push down + UnionPushDown, PushPredicateThroughJoin, + PushPredicateThroughProject, PushPredicateThroughGenerate, ColumnPruning, + // Operator combine ProjectCollapsing, - CombineLimits) :: - Batch("ConstantFolding", FixedPoint(100), + CombineFilters, + CombineLimits, + // Constant folding NullPropagation, OptimizeIn, ConstantFolding, LikeSimplification, BooleanSimplification, + RemovePositive, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions) :: @@ -61,25 +66,25 @@ object DefaultOptimizer extends Optimizer { } /** - * Pushes operations to either side of a Union. - */ -object UnionPushdown extends Rule[LogicalPlan] { + * Pushes operations to either side of a Union. + */ +object UnionPushDown extends Rule[LogicalPlan] { /** - * Maps Attributes from the left side to the corresponding Attribute on the right side. - */ - def buildRewrites(union: Union): AttributeMap[Attribute] = { + * Maps Attributes from the left side to the corresponding Attribute on the right side. + */ + private def buildRewrites(union: Union): AttributeMap[Attribute] = { assert(union.left.output.size == union.right.output.size) AttributeMap(union.left.output.zip(union.right.output)) } /** - * Rewrites an expression so that it can be pushed to the right side of a Union operator. - * This method relies on the fact that the output attributes of a union are always equal - * to the left child's output. - */ - def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { + * Rewrites an expression so that it can be pushed to the right side of a Union operator. + * This method relies on the fact that the output attributes of a union are always equal + * to the left child's output. + */ + private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { val result = e transform { case a: Attribute => rewrites(a) } @@ -106,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] { } } - /** * Attempts to eliminate the reading of unneeded columns from the query plan using the following * transformations: @@ -115,10 +119,13 @@ object UnionPushdown extends Rule[LogicalPlan] { * - Aggregate * - Project <- Join * - LeftSemiJoin - * - Performing alias substitution. */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) + if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) @@ -153,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] { Join(left, prunedChild(right, allReferences), LeftSemi, condition) + // Push down project through limit, so that we may have chance to push it further. case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // push down project if possible when the child is sort + // Push down project if possible when the child is sort case p @ Project(projectList, s @ Sort(_, _, grandChild)) if s.references.subsetOf(p.outputSet) => s.copy(child = Project(projectList, grandChild)) @@ -175,12 +183,21 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[Project]] operators into one, merging the - * expressions into one single expression. + * Combines two adjacent [[Project]] operators into one and perform alias substitution, + * merging the expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { + + /** Returns true if any expression in projectList is non-deterministic. */ + private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { + projectList.exists(expr => expr.find(!_.deterministic).isDefined) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Project(projectList1, Project(projectList2, child)) => + // We only collapse these two Projects if the child Project's expressions are all + // deterministic. + case Project(projectList1, Project(projectList2, child)) + if !hasNondeterministic(projectList2) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { @@ -207,10 +224,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] { object LikeSimplification extends Rule[LogicalPlan] { // if guards below protect from escapes on trailing %. // Cases like "something\%" are not optimized, but this does not affect correctness. - val startsWith = "([^_%]+)%".r - val endsWith = "%([^_%]+)".r - val contains = "%([^_%]+)%".r - val equalTo = "([^_%]*)".r + private val startsWith = "([^_%]+)%".r + private val endsWith = "%([^_%]+)".r + private val contains = "%([^_%]+)%".r + private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(l, Literal(utf, StringType)) => @@ -255,7 +272,7 @@ object NullPropagation extends Rule[LogicalPlan] { if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { - newChildren(0) + newChildren.head } else { Coalesce(newChildren) } @@ -264,22 +281,23 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) + // MaxOf and MinOf can't do null propagation + case e: MaxOf => e + case e: MinOf => e + // Put exceptional cases above if any - case e: BinaryArithmetic => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - case e: BinaryComparison => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } + case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) + + case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) + case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) + case e: StringRegexExpression => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e } + case e: StringComparison => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) @@ -481,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { grandChild)) } - def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = { + private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = { condition transform { case a: AttributeReference => sourceAliases.getOrElse(a, a) } @@ -621,6 +639,15 @@ object SimplifyCasts extends Rule[LogicalPlan] { } } +/** + * Removes [[UnaryPositive]] identify function + */ +object RemovePositive extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case UnaryPositive(child) => child + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. @@ -657,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ - val MAX_DOUBLE_DIGITS = 15 + private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => @@ -683,3 +710,15 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } } + +/** + * Replaces logical [[Distinct]] operator with an [[Aggregate]] operator. + * {{{ + * SELECT DISTINCT f1, f2 FROM t ==> SELECT f1, f2 FROM t GROUP BY f1, f2 + * }}} + */ +object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Distinct(child) => Aggregate(child.output, child.output, child) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cd54d04814ea..179a348d5baa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -143,11 +143,11 @@ object PartialAggregation { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = + val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = groupingExpressions.filter(!_.isInstanceOf[Literal]).map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) - }.toMap + } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. @@ -156,20 +156,15 @@ object PartialAggregation { partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression => - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - namedGroupingExpressions - .get(e.transform { case Alias(g: ExtractValue, _) => g }) - .map(_.toAttribute) - .getOrElse(e) + namedGroupingExpressions.collectFirst { + case (expr, ne) if expr semanticEquals e => ne.toAttribute + }.getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + val partialComputation = namedGroupingExpressions.map(_._2) ++ + partialEvaluations.values.flatMap(_.partialEvaluations) - val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) Some( (namedGroupingAttributes, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7967189cacb2..2f545bb43216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -81,17 +81,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionDown(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } @@ -114,17 +113,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionUp(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index e3e070f0ff30..1868f119f0e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, analysis} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{StructType, StructField} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} +import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) @@ -32,11 +31,11 @@ object LocalRelation { def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) - LocalRelation(output, data.map(converter(_).asInstanceOf[Row])) + LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } } -case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil) +case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) extends LeafNode with analysis.MultiInstanceRelation { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index dba69659afc8..e911b907e853 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode @@ -50,19 +50,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] * should return `false`). */ - lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved + lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved override protected def statePrefix = if (!resolved) "'" else super.statePrefix /** * Returns true if all its children of this query plan have been resolved. */ - def childrenResolved: Boolean = !children.exists(!_.resolved) + def childrenResolved: Boolean = children.forall(_.resolved) /** * Returns true when the given logical plan will return the same results as this logical plan. * - * Since its likely undecideable to generally determine if two given plans will produce the same + * Since its likely undecidable to generally determine if two given plans will produce the same * results, it is okay for this function to return false, even if the results are actually * the same. Such behavior will not affect correctness, only the application of performance * enhancements like caching. However, it is not acceptable to return true if the results could @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { val input = children.flatMap(_.output) productIterator.map { // Children are checked using sameResult above. - case tn: TreeNode[_] if children contains tn => null + case tn: TreeNode[_] if containsChild(tn) => null case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) case s: Option[_] => s.map { case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) @@ -111,9 +111,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def resolveChildren( nameParts: Seq[String], - resolver: Resolver, - throwErrors: Boolean = false): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) + resolver: Resolver): Option[NamedExpression] = + resolve(nameParts, children.flatMap(_.output), resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -122,9 +121,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def resolve( nameParts: Seq[String], - resolver: Resolver, - throwErrors: Boolean = false): Option[NamedExpression] = - resolve(nameParts, output, resolver, throwErrors) + resolver: Resolver): Option[NamedExpression] = + resolve(nameParts, output, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -134,7 +132,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(parseAttributeName(name), resolver, true) + resolve(parseAttributeName(name), output, resolver) } /** @@ -163,7 +161,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { if (tmp.nonEmpty) throw e inBacktick = true } else if (char == '.') { - if (tmp.isEmpty) throw e + if (name(i - 1) == '.' || i == name.length - 1) throw e nameParts += tmp.mkString tmp.clear() } else { @@ -172,7 +170,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } i += 1 } - if (tmp.isEmpty || inBacktick) throw e + if (inBacktick) throw e nameParts += tmp.mkString nameParts.toSeq } @@ -219,8 +217,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { protected def resolve( nameParts: Seq[String], input: Seq[Attribute], - resolver: Resolver, - throwErrors: Boolean): Option[NamedExpression] = { + resolver: Resolver): Option[NamedExpression] = { // A sequence of possible candidate matches. // Each candidate is a tuple. The first element is a resolved attribute, followed by a list @@ -254,19 +251,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - try { - // The foldLeft adds GetFields for every remaining parts of the identifier, - // and aliases it with the last part of the identifier. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add GetField("c", GetField("b", a)), and alias - // the final expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - val aliasName = nestedFields.last - Some(Alias(fieldExprs, aliasName)()) - } catch { - case a: AnalysisException if !throwErrors => None - } + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and wrap it with UnresolvedAlias which will be removed later. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as + // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). + val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => + ExtractValue(expr, Literal(fieldName), resolver)) + Some(UnresolvedAlias(fieldExprs)) // No matches. case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 01f4b6e9bb77..fae339808c23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashSet case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -93,7 +94,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved: Boolean = childrenResolved && - left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -130,6 +131,14 @@ case class Join( } } +/** + * A hint for the optimizer that we should broadcast the `child` if used in a join operator. + */ +case class BroadcastHint(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + + case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output } @@ -220,28 +229,82 @@ case class Window( /** * Apply the all of the GroupExpressions to every input row, hence we will get * multiple output rows for a input row. - * @param projections The group of expressions, all of the group expressions should - * output the same schema specified by the parameter `output` - * @param output The output Schema + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByExprs The grouping by expressions * @param child Child operator */ case class Expand( - projections: Seq[GroupExpression], - output: Seq[Attribute], + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + gid: Attribute, child: LogicalPlan) extends UnaryNode { override def statistics: Statistics = { val sizeInBytes = child.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + val projections: Seq[Seq[Expression]] = expand() + + /** + * Extract attribute set according to the grouping id + * @param bitmask bitmask to represent the selected of the attribute sequence + * @param exprs the attributes in sequence + * @return the attributes of non selected specified via bitmask (with the bit set to 1) + */ + private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) + : OpenHashSet[Expression] = { + val set = new OpenHashSet[Expression](2) + + var bit = exprs.length - 1 + while (bit >= 0) { + if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) + bit -= 1 + } + + set + } + + /** + * Create an array of Projections for the child projection, and replace the projections' + * expressions which equal GroupBy expressions with Literal(null), if those expressions + * are not set for this grouping set (according to the bit mask). + */ + private[this] def expand(): Seq[Seq[Expression]] = { + val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] + + bitmasks.foreach { bitmask => + // get the non selected grouping attributes according to the bit mask + val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) + + val substitution = (child.output :+ gid).map(expr => expr transformDown { + case x: Expression if nonSelectedGroupExprSet.contains(x) => + // if the input attribute in the Invalid Grouping Expression set of for this group + // replace it with constant null + Literal.create(null, expr.dataType) + case x if x == gid => + // replace the groupingId with concrete value (the bit mask) + Literal.create(bitmask, IntegerType) + }) + + result += substitution + } + + result.toSeq + } + + override def output: Seq[Attribute] = { + child.output :+ gid + } } trait GroupingAnalytics extends UnaryNode { self: Product => - def gid: AttributeReference def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } /** @@ -256,17 +319,16 @@ trait GroupingAnalytics extends UnaryNode { * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. - * The associated output will be one of the value in `bitmasks` */ case class GroupingSets( bitmasks: Seq[Int], groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -276,15 +338,15 @@ case class GroupingSets( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Cube( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -295,15 +357,15 @@ case class Cube( * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions * will be considered as constant null if it appears in the expressions - * @param gid The attribute represents the virtual column GROUPING__ID, and it's also - * the bitmask indicates the selected GroupBy Expressions for each - * aggregating output row. */ case class Rollup( groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -339,6 +401,9 @@ case class Sample( override def output: Seq[Attribute] = child.output } +/** + * Returns a new logical plan that dedups input rows. + */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index fb4217a44807..42dead7c2842 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -169,7 +170,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def keyExpressions: Seq[Expression] = expressions - override def eval(input: Row = null): EvaluatedType = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -213,6 +214,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def keyExpressions: Seq[Expression] = ordering.map(_.child) - override def eval(input: Row): EvaluatedType = + override def eval(input: InternalRow): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index bc2ad34523d2..09f6c6b0ec42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -59,9 +59,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val origin: Origin = CurrentOrigin.get - /** Returns a Seq of the children of this node */ + /** + * Returns a Seq of the children of this node. + * Children should not change. Immutability required for containsChild optimization + */ def children: Seq[BaseType] + lazy val containsChild: Set[TreeNode[_]] = children.toSet + /** * Faster version of equality which short-circuits when two treeNodes are the same instance. * We don't just override Object.equals, as doing so prevents the scala compiler from @@ -147,7 +152,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def mapChildren(f: BaseType => BaseType): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) if (newChild fastEquals arg) { arg @@ -173,7 +178,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { val newArgs = productIterator.map { // Handle Seq[TreeNode] in TreeNode parameters. case s: Seq[_] => s.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -185,7 +190,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { case nonChild: AnyRef => nonChild case null => null } - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) if (newChild fastEquals oldChild) { @@ -238,7 +243,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -246,7 +251,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if children contains arg => + case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -254,10 +259,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) if (!(newChild fastEquals arg)) { changed = true @@ -280,7 +285,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildrenUp(rule); + val afterRuleOnChildren = transformChildrenUp(rule) if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) @@ -295,7 +300,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { var changed = false val newArgs = productIterator.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -303,7 +308,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } - case Some(arg: TreeNode[_]) if children contains arg => + case Some(arg: TreeNode[_]) if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -311,10 +316,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { - case arg: TreeNode[_] if children contains arg => + case arg: TreeNode[_] if containsChild(arg) => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) if (!(newChild fastEquals arg)) { changed = true @@ -344,11 +349,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { - val defaultCtor = - getClass.getConstructors - .find(_.getParameterTypes.size != 0) - .headOption - .getOrElse(sys.error(s"No valid constructor for $nodeName")) + val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) + if (ctors.isEmpty) { + sys.error(s"No valid constructor for $nodeName") + } + val defaultCtor = ctors.maxBy(_.getParameterTypes.size) try { CurrentOrigin.withOrigin(origin) { @@ -383,8 +388,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { - case tn: TreeNode[_] if children contains tn => Nil + case tn: TreeNode[_] if containsChild(tn) => Nil case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil case other => other :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala new file mode 100644 index 000000000000..4269ad5d5673 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.sql.{Date, Timestamp} +import java.text.{DateFormat, SimpleDateFormat} +import java.util.{Calendar, TimeZone} + +/** + * Helper functions for converting between internal and external date and time representations. + * Dates are exposed externally as java.sql.Date and are represented internally as the number of + * dates since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp + * and are stored internally as longs, which are capable of storing timestamps with 100 nanosecond + * precision. + */ +object DateTimeUtils { + final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L + + // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian + final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 + final val SECONDS_PER_DAY = 60 * 60 * 24L + final val HUNDRED_NANOS_PER_SECOND = 1000L * 1000L * 10L + final val NANOS_PER_SECOND = HUNDRED_NANOS_PER_SECOND * 100 + + + // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. + private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { + override protected def initialValue: TimeZone = { + Calendar.getInstance.getTimeZone + } + } + + // `SimpleDateFormat` is not thread-safe. + private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + } + } + + // `SimpleDateFormat` is not thread-safe. + private val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd") + } + } + + // we should use the exact day as Int, for example, (year, month, day) -> day + def millisToDays(millisUtc: Long): Int = { + // SPARK-6785: use Math.floor so negative number of days (dates before 1970) + // will correctly work as input for function toJavaDate(Int) + val millisLocal = millisUtc.toDouble + threadLocalLocalTimeZone.get().getOffset(millisUtc) + Math.floor(millisLocal / MILLIS_PER_DAY).toInt + } + + // reverse of millisToDays + def daysToMillis(days: Int): Long = { + val millisUtc = days.toLong * MILLIS_PER_DAY + millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc) + } + + def dateToString(days: Int): String = + threadLocalDateFormat.get.format(toJavaDate(days)) + + // Converts Timestamp to string according to Hive TimestampWritable convention. + def timestampToString(num100ns: Long): String = { + val ts = toJavaTimestamp(num100ns) + val timestampString = ts.toString + val formatted = threadLocalTimestampFormat.get.format(ts) + + if (timestampString.length > 19 && timestampString.substring(19) != ".0") { + formatted + timestampString.substring(19) + } else { + formatted + } + } + + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + Timestamp.valueOf(s) + } else { + Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } + + /** + * Returns the number of days since epoch from from java.sql.Date. + */ + def fromJavaDate(date: Date): Int = { + millisToDays(date.getTime) + } + + /** + * Returns a java.sql.Date from number of days since epoch. + */ + def toJavaDate(daysSinceEpoch: Int): Date = { + new Date(daysToMillis(daysSinceEpoch)) + } + + /** + * Returns a java.sql.Timestamp from number of 100ns since epoch. + */ + def toJavaTimestamp(num100ns: Long): Timestamp = { + // setNanos() will overwrite the millisecond part, so the milliseconds should be + // cut off at seconds + var seconds = num100ns / HUNDRED_NANOS_PER_SECOND + var nanos = num100ns % HUNDRED_NANOS_PER_SECOND + // setNanos() can not accept negative value + if (nanos < 0) { + nanos += HUNDRED_NANOS_PER_SECOND + seconds -= 1 + } + val t = new Timestamp(seconds * 1000) + t.setNanos(nanos.toInt * 100) + t + } + + /** + * Returns the number of 100ns since epoch from java.sql.Timestamp. + */ + def fromJavaTimestamp(t: Timestamp): Long = { + if (t != null) { + t.getTime() * 10000L + (t.getNanos().toLong / 100) % 10000L + } else { + 0L + } + } + + /** + * Returns the number of 100ns (hundred of nanoseconds) since epoch from Julian day + * and nanoseconds in a day + */ + def fromJulianDay(day: Int, nanoseconds: Long): Long = { + // use Long to avoid rounding errors + val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2 + seconds * HUNDRED_NANOS_PER_SECOND + nanoseconds / 100L + } + + /** + * Returns Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) + */ + def toJulianDay(num100ns: Long): (Int, Long) = { + val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2 + val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH + val secondsInDay = seconds % SECONDS_PER_DAY + val nanos = (num100ns % HUNDRED_NANOS_PER_SECOND) * 100L + (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala deleted file mode 100644 index 3f92be4a55d7..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import java.sql.Date -import java.text.SimpleDateFormat -import java.util.{Calendar, TimeZone} - -import org.apache.spark.sql.catalyst.expressions.Cast - -/** - * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date - */ -object DateUtils { - private val MILLIS_PER_DAY = 86400000 - - // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. - private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { - override protected def initialValue: TimeZone = { - Calendar.getInstance.getTimeZone - } - } - - private def javaDateToDays(d: Date): Int = { - millisToDays(d.getTime) - } - - // we should use the exact day as Int, for example, (year, month, day) -> day - def millisToDays(millisLocal: Long): Int = { - ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt - } - - private def toMillisSinceEpoch(days: Int): Long = { - val millisUtc = days.toLong * MILLIS_PER_DAY - millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) - } - - def fromJavaDate(date: java.sql.Date): Int = { - javaDateToDays(date) - } - - def toJavaDate(daysSinceEpoch: Int): java.sql.Date = { - new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch)) - } - - def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) - - def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { - // JDBC escape string - if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) - } else { - java.sql.Date.valueOf(s) - } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } - } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala new file mode 100644 index 000000000000..191d5e6399fc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +/** + * Build a map with String type of key, and it also supports either key case + * sensitive or insensitive. + */ +object StringKeyHashMap { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { + case false => new StringKeyHashMap[T](_.toLowerCase) + case true => new StringKeyHashMap[T](identity) + } +} + + +class StringKeyHashMap[T](normalizer: (String) => String) { + private val base = new collection.mutable.HashMap[String, T]() + + def apply(key: String): T = base(normalizer(key)) + + def get(key: String): Option[T] = base.get(normalizer(key)) + + def put(key: String, value: T): Option[T] = base.put(normalizer(key), value) + + def remove(key: String): Option[T] = base.remove(normalizer(key)) + + def iterator: Iterator[(String, T)] = base.toIterator +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala new file mode 100644 index 000000000000..3148309a2166 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types._ + +/** + * Helper functions to check for valid data types. + */ +object TypeUtils { + def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[NumericType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t") + } + } + + def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[IntegralType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") + } + } + + def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[AtomicType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t") + } + } + + def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { + if (types.distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + def getNumeric(t: DataType): Numeric[Any] = + t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + + def getOrdering(t: DataType): Ordering[Any] = + t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] + + def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + x.length - y.length + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 9d613a940ee8..07054166a5e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -83,7 +83,7 @@ package object util { } def resourceToString( - resource:String, + resource: String, encoding: String = "UTF-8", classLoader: ClassLoader = Utils.getSparkClassLoader): String = { new String(resourceToBytes(resource, classLoader), encoding) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala new file mode 100644 index 000000000000..fb1b47e94621 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} + +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils + +/** + * A non-concrete data type, reserved for internal uses. + */ +private[sql] abstract class AbstractDataType { + /** + * The default concrete type to use if we want to cast a null literal into this type. + */ + private[sql] def defaultConcreteType: DataType + + /** + * Returns true if this data type is a parent of the `childCandidate`. + */ + private[sql] def isParentOf(childCandidate: DataType): Boolean + + /** Readable string representation for the type. */ + private[sql] def simpleString: String +} + + +/** + * A collection of types that can be used to specify type constraints. The sequence also specifies + * precedence: an earlier type takes precedence over a latter type. + * + * {{{ + * TypeCollection(StringType, BinaryType) + * }}} + * + * This means that we prefer StringType over BinaryType if it is possible to cast to StringType. + */ +private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) + extends AbstractDataType { + + require(types.nonEmpty, s"TypeCollection ($types) cannot be empty") + + private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = false + + private[sql] override def simpleString: String = { + types.map(_.simpleString).mkString("(", " or ", ")") + } +} + + +private[sql] object TypeCollection { + + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) + + def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { + case typ: TypeCollection => Some(typ.types) + case _ => None + } +} + + +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] + + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) + } +} + + +/** + * :: DeveloperApi :: + * Numeric data types. + */ +abstract class NumericType extends AtomicType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the objects constructor. This means there is no + // longer an no argument constructor and thus the JVM cannot serialize the object anymore. + private[sql] val numeric: Numeric[InternalType] +} + + +private[sql] object NumericType { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ NumericType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] +} + + +private[sql] object IntegralType { + /** + * Enables matching against IntegralType for expressions: + * {{{ + * case Cast(child @ IntegralType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] +} + + +private[sql] abstract class IntegralType extends NumericType { + private[sql] val integral: Integral[InternalType] +} + + +private[sql] object FractionalType { + /** + * Enables matching against FractionalType for expressions: + * {{{ + * case Cast(child @ FractionalType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] +} + + +private[sql] abstract class FractionalType extends NumericType { + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index b116163facca..43413ec761e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -22,9 +22,17 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi -object ArrayType { +object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) + + private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[ArrayType] + } + + private[sql] override def simpleString: String = "array" } @@ -41,8 +49,6 @@ object ArrayType { * * @param elementType The data type of values. * @param containsNull Indicates if values have `null` values - * - * @group dataType */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala index a581a9e9468e..f2c6f34ea51c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -22,14 +22,13 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.util.TypeUtils /** * :: DeveloperApi :: * The data type representing `Array[Byte]` values. * Please use the singleton [[DataTypes.BinaryType]]. - * - * @group dataType */ @DeveloperApi class BinaryType private() extends AtomicType { @@ -43,11 +42,7 @@ class BinaryType private() extends AtomicType { private[sql] val ordering = new Ordering[InternalType] { def compare(x: Array[Byte], y: Array[Byte]): Int = { - for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - } - x.length - y.length + TypeUtils.compareBinary(x, y) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala index a7f228cefa57..2d8ee3d9bc28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. - * - *@group dataType */ @DeveloperApi class BooleanType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 4d8685796ec7..2ca427975a1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. - * - * @group dataType */ @DeveloperApi class ByteType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a0b261649f66..7d00047d08d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers import org.json4s._ @@ -27,19 +25,15 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.util.Utils /** * :: DeveloperApi :: * The base type of all Spark SQL data types. - * - * @group dataType */ @DeveloperApi -abstract class DataType { +abstract class DataType extends AbstractDataType { /** * Enables matching against DataType for expressions: * {{{ @@ -80,84 +74,10 @@ abstract class DataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def asNullable: DataType -} - - -/** - * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. - */ -protected[sql] abstract class AtomicType extends DataType { - private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] - private[sql] val ordering: Ordering[InternalType] - - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) - } -} - - -/** - * :: DeveloperApi :: - * Numeric data types. - * - * @group dataType - */ -abstract class NumericType extends AtomicType { - // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for - // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets - // desugared by the compiler into an argument to the objects constructor. This means there is no - // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - private[sql] val numeric: Numeric[InternalType] -} - - -private[sql] object NumericType { - /** - * Enables matching against NumericType for expressions: - * {{{ - * case Cast(child @ NumericType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] -} - - -private[sql] object IntegralType { - /** - * Enables matching against IntegralType for expressions: - * {{{ - * case Cast(child @ IntegralType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] -} - - -private[sql] abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[InternalType] -} + private[sql] override def defaultConcreteType: DataType = this -private[sql] object FractionalType { - /** - * Enables matching against FractionalType for expressions: - * {{{ - * case Cast(child @ FractionalType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] -} - - -private[sql] abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[InternalType] - private[sql] val asIntegral: Integral[InternalType] + private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate } @@ -165,6 +85,9 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) + /** + * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` + */ @deprecated("Use DataType.fromJson instead", "1.2.0") def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) @@ -271,7 +194,7 @@ object DataType { protected lazy val structField: Parser[StructField] = ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 04f3379afb38..6b43224feb1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -44,7 +44,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | - "(?i)bigint".r ^^^ LongType | + "(?i)(?:bigint|long)".r ^^^ LongType | "(?i)binary".r ^^^ BinaryType | "(?i)boolean".r ^^^ BooleanType | fixedDecimalType | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 03f0644bc784..1d73e40ffcd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -26,10 +26,11 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: - * The data type representing `java.sql.Date` values. + * A date type, supporting "0001-01-01" through "9999-12-31". + * * Please use the singleton [[DataTypes.DateType]]. * - * @group dataType + * Internally, this is represented as the number of days from epoch (1970-01-01 00:00:00 UTC). */ @DeveloperApi class DateType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 994c5202c15d..5a169488c97e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.{MathContext, RoundingMode} + import org.apache.spark.annotation.DeveloperApi /** @@ -86,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision < 19) { return null // Requested precision is too low to represent this value } - this.decimalVal = BigDecimal(longVal) + this.decimalVal = BigDecimal(unscaled) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) @@ -137,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { - decimalVal + decimalVal(MathContext.UNLIMITED) } else { - BigDecimal(longVal, _scale) + BigDecimal(longVal, _scale)(MathContext.UNLIMITED) } } @@ -263,8 +265,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) - def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + def / (that: Decimal): Decimal = { + if (that.isZero) { + null + } else { + // To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide + // with specified ROUNDING_MODE. + Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id)) + } + } def % (that: Decimal): Decimal = if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) @@ -313,7 +322,7 @@ object Decimal { // See scala.math's Numeric.scala for examples for Scala's built-in types. /** Common methods for Decimal evidence parameters */ - trait DecimalIsConflicted extends Numeric[Decimal] { + private[sql] trait DecimalIsConflicted extends Numeric[Decimal] { override def plus(x: Decimal, y: Decimal): Decimal = x + y override def times(x: Decimal, y: Decimal): Decimal = x * y override def minus(x: Decimal, y: Decimal): Decimal = x - y @@ -327,12 +336,12 @@ object Decimal { } /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ - object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + private[sql] object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { override def div(x: Decimal, y: Decimal): Decimal = x / y } /** A [[scala.math.Integral]] evidence parameter for Decimals. */ - object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + private[sql] object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { override def quot(x: Decimal, y: Decimal): Decimal = x / y override def rem(x: Decimal, y: Decimal): Decimal = x % y } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0f8cecd28f7d..127b16ff85be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -20,13 +20,18 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression /** Precision parameters for a Decimal */ -case class PrecisionInfo(precision: Int, scale: Int) - +case class PrecisionInfo(precision: Int, scale: Int) { + if (scale > precision) { + throw new AnalysisException( + s"Decimal scale ($scale) cannot be greater than precision ($precision).") + } +} /** * :: DeveloperApi :: @@ -34,8 +39,6 @@ case class PrecisionInfo(precision: Int, scale: Int) * A Decimal that might have fixed precision and scale, or unlimited values for these. * * Please use [[DataTypes.createDecimalType()]] to create a specific instance. - * - * @group dataType */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { @@ -79,15 +82,24 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ -object DecimalType { +object DecimalType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = Unlimited + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[DecimalType] + } + + private[sql] override def simpleString: String = "decimal" + val Unlimited: DecimalType = DecimalType(None) - object Fixed { + private[sql] object Fixed { def unapply(t: DecimalType): Option[(Int, Int)] = t.precisionInfo.map(p => (p.precision, p.scale)) } - object Expression { + private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 66766623213c..986c2ab05538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. - * - * @group dataType */ @DeveloperApi class DoubleType private() extends FractionalType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 1d5a2f4f6f86..9bd48ece83a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. - * - * @group dataType */ @DeveloperApi class FloatType private() extends FractionalType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index 74e464c08287..a2c6e19b05b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. - * - * @group dataType */ @DeveloperApi class IntegerType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 390675782e5f..2b3adf6ade83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. - * - * @group dataType */ @DeveloperApi class LongType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index cfdf49307441..868dea13d971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -30,8 +30,6 @@ import org.json4s.JsonDSL._ * @param keyType The data type of map keys. * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. - * - * @group dataType */ case class MapType( keyType: DataType, @@ -69,7 +67,16 @@ case class MapType( } -object MapType { +object MapType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType) + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[MapType] + } + + private[sql] override def simpleString: String = "map" + /** * Construct a [[MapType]] object with the given key type and value type. * The `valueContainsNull` is true. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala index b64b07431fa9..aa84115c2e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -23,8 +23,6 @@ import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. - * - * @group dataType */ @DeveloperApi class NullType private() extends DataType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java index a64d2bb7cde3..df64a878b6b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -24,11 +24,11 @@ /** * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. - * + *

    * WARNING: This annotation will only work if both Java and Scala reflection return the same class * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class * is enclosed in an object (a singleton). - * + *

    * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 73e9ec780b0a..a13119e65906 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -26,8 +26,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock /** * :: DeveloperApi :: * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. - * - * @group dataType */ @DeveloperApi class ShortType private() extends IntegralType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala index 134ab0af4e0d..a7627a2de161 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -22,12 +22,11 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.unsafe.types.UTF8String /** * :: DeveloperApi :: * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. - * - * @group dataType */ @DeveloperApi class StringType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 7e00a27dfe72..3b17566d54d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -87,14 +87,12 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} * val row = Row(Row(1, 2, true)) * // row: Row = [[1,2,true]] * }}} - * - * @group dataType */ @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ - protected def this() = this(null) + def this() = this(Array.empty[StructField]) /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -103,6 +101,108 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + /** + * Creates a new [[StructType]] by adding a new field. + * {{{ + * val struct = (new StructType) + * .add(StructField("a", IntegerType, true)) + * .add(StructField("b", LongType, false)) + * .add(StructField("c", StringType, true)) + *}}} + */ + def add(field: StructField): StructType = { + StructType(fields :+ field) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType) + * .add("b", LongType) + * .add("c", StringType) + */ + def add(name: String, dataType: DataType): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable = true, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType, true) + * .add("b", LongType, false) + * .add("c", StringType, true) + */ + def add(name: String, dataType: DataType, nullable: Boolean): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata. + * {{{ + * val struct = (new StructType) + * .add("a", IntegerType, true, Metadata.empty) + * .add("b", LongType, false, Metadata.empty) + * .add("c", StringType, true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, metadata)) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int") + * .add("b", "long") + * .add("c", "string") + * }}} + */ + def add(name: String, dataType: String): StructType = { + add(name, DataTypeParser.parse(dataType), nullable = true, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int", true) + * .add("b", "long", false) + * .add("c", "string", true) + * }}} + */ + def add(name: String, dataType: String, nullable: Boolean): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata where the + * dataType is specified as a String. + * {{{ + * val struct = (new StructType) + * .add("a", "int", true, Metadata.empty) + * .add("b", "long", false, Metadata.empty) + * .add("c", "string", true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: String, + nullable: Boolean, + metadata: Metadata): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, metadata) + } + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. @@ -201,7 +301,15 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru } -object StructType { +object StructType extends AbstractDataType { + + private[sql] override def defaultConcreteType: DataType = new StructType + + private[sql] override def isParentOf(childCandidate: DataType): Boolean = { + childCandidate.isInstanceOf[StructType] + } + + private[sql] override def simpleString: String = "struct" def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) @@ -230,10 +338,10 @@ object StructType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightFields - .find(_.name == leftName) + rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => leftField.copy( dataType = merge(leftType, rightType), @@ -243,8 +351,9 @@ object StructType { .foreach(newFields += _) } + val leftMapped = fieldsMap(leftFields) rightFields - .filterNot(f => leftFields.map(_.name).contains(f.name)) + .filterNot(f => leftMapped.get(f.name).nonEmpty) .foreach(newFields += _) StructType(newFields) @@ -264,4 +373,9 @@ object StructType { case _ => throw new SparkException(s"Failed to merge incompatible data types $left and $right") } + + private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { + import scala.collection.breakOut + fields.map(s => (s.name, s))(breakOut) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index aebabfc47592..de4b511edccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.sql.Timestamp - import scala.math.Ordering import scala.reflect.runtime.universe.typeTag @@ -30,26 +28,22 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock * :: DeveloperApi :: * The data type representing `java.sql.Timestamp` values. * Please use the singleton [[DataTypes.TimestampType]]. - * - * @group dataType */ @DeveloperApi class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Timestamp + private[sql] type InternalType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) - } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the TimestampType is 12 bytes. */ - override def defaultSize: Int = 12 + override def defaultSize: Int = 8 private[spark] override def asNullable: TimestampType = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala deleted file mode 100644 index bc9c37bf2d5d..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.types - -import java.util.Arrays - -import org.apache.spark.annotation.DeveloperApi - -/** - * :: DeveloperApi :: - * A UTF-8 String, as internal representation of StringType in SparkSQL - * - * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, - * search, see http://en.wikipedia.org/wiki/UTF-8 for details. - * - * Note: This is not designed for general use cases, should not be used outside SQL. - */ -@DeveloperApi -final class UTF8String extends Ordered[UTF8String] with Serializable { - - private[this] var bytes: Array[Byte] = _ - - /** - * Update the UTF8String with String. - */ - def set(str: String): UTF8String = { - bytes = str.getBytes("utf-8") - this - } - - /** - * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 - */ - def set(bytes: Array[Byte]): UTF8String = { - this.bytes = bytes - this - } - - /** - * Return the number of bytes for a code point with the first byte as `b` - * @param b The first byte of a code point - */ - @inline - private[this] def numOfBytes(b: Byte): Int = { - val offset = (b & 0xFF) - 192 - if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 - } - - /** - * Return the number of code points in it. - * - * This is only used by Substring() when `start` is negative. - */ - def length(): Int = { - var len = 0 - var i: Int = 0 - while (i < bytes.length) { - i += numOfBytes(bytes(i)) - len += 1 - } - len - } - - def getBytes: Array[Byte] = { - bytes - } - - /** - * Return a substring of this, - * @param start the position of first code point - * @param until the position after last code point - */ - def slice(start: Int, until: Int): UTF8String = { - if (until <= start || start >= bytes.length || bytes == null) { - new UTF8String - } - - var c = 0 - var i: Int = 0 - while (c < start && i < bytes.length) { - i += numOfBytes(bytes(i)) - c += 1 - } - var j = i - while (c < until && j < bytes.length) { - j += numOfBytes(bytes(j)) - c += 1 - } - UTF8String(Arrays.copyOfRange(bytes, i, j)) - } - - def contains(sub: UTF8String): Boolean = { - val b = sub.getBytes - if (b.length == 0) { - return true - } - var i: Int = 0 - while (i <= bytes.length - b.length) { - // In worst case, it's O(N*K), but should works fine with SQL - if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { - return true - } - i += 1 - } - false - } - - def startsWith(prefix: UTF8String): Boolean = { - val b = prefix.getBytes - if (b.length > bytes.length) { - return false - } - Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) - } - - def endsWith(suffix: UTF8String): Boolean = { - val b = suffix.getBytes - if (b.length > bytes.length) { - return false - } - Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) - } - - def toUpperCase(): UTF8String = { - // upper case depends on locale, fallback to String. - UTF8String(toString().toUpperCase) - } - - def toLowerCase(): UTF8String = { - // lower case depends on locale, fallback to String. - UTF8String(toString().toLowerCase) - } - - override def toString(): String = { - new String(bytes, "utf-8") - } - - override def clone(): UTF8String = new UTF8String().set(this.bytes) - - override def compare(other: UTF8String): Int = { - var i: Int = 0 - val b = other.getBytes - while (i < bytes.length && i < b.length) { - val res = bytes(i).compareTo(b(i)) - if (res != 0) return res - i += 1 - } - bytes.length - b.length - } - - override def compareTo(other: UTF8String): Int = { - compare(other) - } - - override def equals(other: Any): Boolean = other match { - case s: UTF8String => - Arrays.equals(bytes, s.getBytes) - case s: String => - // This is only used for Catalyst unit tests - // fail fast - bytes.length >= s.length && length() == s.length && toString() == s - case _ => - false - } - - override def hashCode(): Int = { - Arrays.hashCode(bytes) - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -object UTF8String { - // number of tailing bytes in a UTF8 sequence for a code point - // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 - private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6, 6, 6) - - /** - * Create a UTF-8 String from String - */ - def apply(s: String): UTF8String = { - if (s != null) { - new UTF8String().set(s) - } else{ - null - } - } - - /** - * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 - */ - def apply(bytes: Array[Byte]): UTF8String = { - if (bytes != null) { - new UTF8String().set(bytes) - } else { - null - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala new file mode 100644 index 000000000000..13aad467fa57 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.Double.longBitsToDouble +import java.lang.Float.intBitsToFloat +import java.math.MathContext + +import scala.util.Random + +import org.apache.spark.sql.types._ + +/** + * Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random + * values; instead, they're biased to return "interesting" values (such as maximum / minimum values) + * with higher probability. + */ +object RandomDataGenerator { + + /** + * The conditional probability of a non-null value being drawn from a set of "interesting" values + * instead of being chosen uniformly at random. + */ + private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.5f + + /** + * The probability of the generated value being null + */ + private val PROBABILITY_OF_NULL: Float = 0.1f + + private val MAX_STR_LEN: Int = 1024 + private val MAX_ARR_SIZE: Int = 128 + private val MAX_MAP_SIZE: Int = 128 + + /** + * Helper function for constructing a biased random number generator which returns "interesting" + * values with a higher probability. + */ + private def randomNumeric[T]( + rand: Random, + uniformRand: Random => T, + interestingValues: Seq[T]): Some[() => T] = { + val f = () => { + if (rand.nextFloat() <= PROBABILITY_OF_INTERESTING_VALUE) { + interestingValues(rand.nextInt(interestingValues.length)) + } else { + uniformRand(rand) + } + } + Some(f) + } + + /** + * Returns a function which generates random values for the given [[DataType]], or `None` if no + * random data generator is defined for that data type. The generated values will use an external + * representation of the data type; for example, the random generator for [[DateType]] will return + * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a + * [[org.apache.spark.Row]]. + * + * @param dataType the type to generate values for + * @param nullable whether null values should be generated + * @param seed an optional seed for the random number generator + * @return a function which can be called to generate random values. + */ + def forType( + dataType: DataType, + nullable: Boolean = true, + seed: Option[Long] = None): Option[() => Any] = { + val rand = new Random() + seed.foreach(rand.setSeed) + + val valueGenerator: Option[() => Any] = dataType match { + case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN))) + case BinaryType => Some(() => { + val arr = new Array[Byte](rand.nextInt(MAX_STR_LEN)) + rand.nextBytes(arr) + arr + }) + case BooleanType => Some(() => rand.nextBoolean()) + case DateType => Some(() => new java.sql.Date(rand.nextInt())) + case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) + case DecimalType.Unlimited => Some( + () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED)) + case DoubleType => randomNumeric[Double]( + rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, + Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) + case FloatType => randomNumeric[Float]( + rand, r => intBitsToFloat(r.nextInt()), Seq(Float.MinValue, Float.MinPositiveValue, + Float.MaxValue, Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f)) + case ByteType => randomNumeric[Byte]( + rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte)) + case IntegerType => randomNumeric[Int]( + rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0)) + case LongType => randomNumeric[Long]( + rand, _.nextLong(), Seq(Long.MinValue, Long.MaxValue, 0L)) + case ShortType => randomNumeric[Short]( + rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) + case NullType => Some(() => null) + case ArrayType(elementType, containsNull) => { + forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { + elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) + } + } + case MapType(keyType, valueType, valueContainsNull) => { + for ( + keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong())); + valueGenerator <- + forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong())) + ) yield { + () => { + Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap + } + } + } + case StructType(fields) => { + val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => + forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong())) + } + if (maybeFieldGenerators.forall(_.isDefined)) { + val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) + Some(() => Row.fromSeq(fieldGenerators.map(_.apply()))) + } else { + None + } + } + case unsupportedType => None + } + // Handle nullability by wrapping the non-null value generator: + valueGenerator.map { valueGenerator => + if (nullable) { + () => { + if (rand.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + valueGenerator() + } + } + } else { + valueGenerator + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala new file mode 100644 index 000000000000..dbba93dba668 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.types._ + +/** + * Tests of [[RandomDataGenerator]]. + */ +class RandomDataGeneratorSuite extends SparkFunSuite { + + /** + * Tests random data generation for the given type by using it to generate random values then + * converting those values into their Catalyst equivalents using CatalystTypeConverters. + */ + def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) + val generator = RandomDataGenerator.forType(dataType, nullable).getOrElse { + fail(s"Random data generator was not defined for $dataType") + } + if (nullable) { + assert(Iterator.fill(100)(generator()).contains(null)) + } else { + assert(Iterator.fill(100)(generator()).forall(_ != null)) + } + for (_ <- 1 to 10) { + val generatedValue = generator() + toCatalyst(generatedValue) + } + } + + // Basic types: + for ( + dataType <- DataTypeTestUtils.atomicTypes; + nullable <- Seq(true, false) + if !dataType.isInstanceOf[DecimalType] || + dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty + ) { + test(s"$dataType (nullable=$nullable)") { + testRandomDataGeneration(dataType) + } + } + + for ( + arrayType <- DataTypeTestUtils.atomicArrayTypes + if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined + ) { + test(s"$arrayType") { + testRandomDataGeneration(arrayType) + } + } + + val atomicTypesWithDataGenerators = + DataTypeTestUtils.atomicTypes.filter(RandomDataGenerator.forType(_).isDefined) + + // Complex types: + for ( + keyType <- atomicTypesWithDataGenerators; + valueType <- atomicTypesWithDataGenerators + // Scala's BigDecimal.hashCode can lead to OutOfMemoryError on Scala 2.10 (see SI-6173) and + // Spark can hit NumberFormatException errors when converting certain BigDecimals (SPARK-8802). + // For these reasons, we don't support generation of maps with decimal keys. + if !keyType.isInstanceOf[DecimalType] + ) { + val mapType = MapType(keyType, valueType) + test(s"$mapType") { + testRandomDataGeneration(mapType) + } + } + + for ( + colOneType <- atomicTypesWithDataGenerators; + colTwoType <- atomicTypesWithDataGenerators + ) { + val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) + test(s"$structType") { + testRandomDataGeneration(structType) + } + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala new file mode 100644 index 000000000000..df0f04563edc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class CatalystTypeConvertersSuite extends SparkFunSuite { + + private val simpleTypes: Seq[DataType] = Seq( + StringType, + DateType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + test("null handling in rows") { + val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) + val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) + val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) + + val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) + assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) + } + + test("null handling for individual values") { + for (dataType <- simpleTypes) { + assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) + } + } + + test("option handling in convertToCatalyst") { + // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with + // createToCatalystConverter but it may not actually matter as this is only called internally + // in a handful of places where we don't expect to receive Options. + assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) + } + + test("option handling in createToCatalystConverter") { + assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index ea82cd2622de..c046dbf4dc2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.physical._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -class DistributionSuite extends FunSuite { +class DistributionSuite extends SparkFunSuite { protected def checkSatisfied( inputPartitioning: Partitioning, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index bbc0b661a0c0..b4b00f558463 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ case class PrimitiveData( @@ -75,8 +73,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } -class ScalaReflectionSuite extends FunSuite { - import ScalaReflection._ +class ScalaReflectionSuite extends SparkFunSuite { + import org.apache.spark.sql.catalyst.ScalaReflection._ test("primitive data") { val schema = schemaFor[PrimitiveData] @@ -253,14 +251,14 @@ class ScalaReflectionSuite extends FunSuite { } assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) } test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val convertedData = Row(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) + val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) val dataType = schemaFor[PrimitiveData].dataType - assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) + assert(CatalystTypeConverters.createToCatalystConverter(dataType)(data) === convertedData) } test("convert Option[Product] to catalyst") { @@ -268,9 +266,9 @@ class ScalaReflectionSuite extends FunSuite { val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), Some(primitiveData)) val dataType = schemaFor[OptionalData].dataType - val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, - Row(1, 1, 1, 1, 1, 1, true)) - assert(CatalystTypeConverters.convertToCatalyst(data, dataType) === convertedData) + val convertedData = InternalRow(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, + InternalRow(1, 1, 1, 1, 1, 1, true)) + assert(CatalystTypeConverters.createToCatalystConverter(dataType)(data) === convertedData) } test("infer schema from case class with multiple constructors") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 890ea2a84b82..b93a3abc6ebd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.Command -import org.scalatest.FunSuite private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -28,7 +28,7 @@ private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Comman } private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") + protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") override protected lazy val start: Parser[LogicalPlan] = set @@ -39,7 +39,7 @@ private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { } private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("EXECUTE") + protected val EXECUTE = Keyword("EXECUTE") override protected lazy val start: Parser[LogicalPlan] = set @@ -49,7 +49,7 @@ private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { } } -class SqlParserSuite extends FunSuite { +class SqlParserSuite extends SparkFunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e1d6ac462fbc..77ca080f366c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends FunSuite with BeforeAndAfter { +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -155,7 +156,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { caseSensitive: Boolean = true): Unit = { test(name) { val error = intercept[AnalysisException] { - if(caseSensitive) { + if (caseSensitive) { caseSensitiveAnalyze(plan) } else { caseInsensitiveAnalyze(plan) @@ -166,6 +167,19 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } } + errorTest( + "unresolved window function", + testRelation2.select( + WindowExpression( + UnresolvedWindowFunction( + "lead", + UnresolvedAttribute("c") :: Nil), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "lead" :: "window functions currently requires a HiveContext" :: Nil) + errorTest( "too many generators", listRelation.select(Explode('list).as('a), Explode('list).as('b)), @@ -179,7 +193,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) errorTest( "non-boolean filters", @@ -250,9 +264,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", StringType)(exprId = ExprId(2)))) + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) assert(plan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 565b1cfe019c..7bac97b7894f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.SimpleCatalystConf -class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { +class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) val catalog = new SimpleCatalog(conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) @@ -91,8 +92,10 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(LessThan(i, d1), DecimalType.Unlimited) - checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala new file mode 100644 index 000000000000..8e0551b23eea --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + +class ExpressionTypeCheckingSuite extends SparkFunSuite { + + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) + + def assertError(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains( + s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains(errorMessage)) + } + + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze + SimpleAnalyzer.checkAnalysis(analyzed) + } + + def assertErrorForDifferingTypes(expr: Expression): Unit = { + assertError(expr, + s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + } + + test("check types for unary arithmetic") { + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + } + + test("check types for binary arithmetic") { + // We will cast String to Double for binary arithmetic + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + assertErrorForDifferingTypes(Add('intField, 'booleanField)) + assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) + assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) + assertErrorForDifferingTypes(Divide('intField, 'booleanField)) + assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) + assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) + assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + } + + test("check types for predicates") { + // We will cast String to Double for binary comparison + assertSuccess(EqualTo('intField, 'stringField)) + assertSuccess(EqualNullSafe('intField, 'stringField)) + assertSuccess(LessThan('intField, 'stringField)) + assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + // We will transform EqualTo with numeric and boolean types to CaseKeyWhen + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) + + assertError(EqualTo('intField, 'complexField), "differing types") + assertError(EqualNullSafe('intField, 'complexField), "differing types") + + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) + assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + + assertError( + LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError( + LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError( + GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError( + GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError( + If('intField, 'stringField, 'stringField), + "type of predicate expression in If should be boolean") + assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) + + assertError( + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + "WHEN expressions in CaseWhen should all be boolean type") + } + + test("check types for aggregates") { + // We will cast String to Double for sum and average + assertSuccess(Sum('stringField)) + assertSuccess(SumDistinct('stringField)) + assertSuccess(Average('stringField)) + + assertError(Min('complexField), "function min accepts non-complex type") + assertError(Max('complexField), "function max accepts non-complex type") + assertError(Sum('booleanField), "function sum accepts numeric type") + assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type") + assertError(Average('booleanField), "function average accepts numeric type") + } + + test("check types for others") { + assertError(CreateArray(Seq('intField, 'booleanField)), + "input to function array should all be the same type") + assertError(Coalesce(Seq('intField, 'booleanField)), + "input to function coalesce should all be the same type") + assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Explode('intField), + "input to function explode should be array or map type") + } + + test("check types for CreateNamedStruct") { + assertError( + CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") + assertError( + CreateNamedStruct(Seq(1, "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + assertError( + CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), + "Odd position only allow foldable and not-null StringType expressions") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index fcd745f43cfb..b56426617789 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,18 +20,92 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { + test("eligible implicit type cast") { + def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { + val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + assert(got.map(_.dataType) == Option(expected), + s"Failed to cast $from to $to") + } + + shouldCast(NullType, NullType, NullType) + shouldCast(NullType, IntegerType, IntegerType) + shouldCast(NullType, DecimalType, DecimalType.Unlimited) + + // TODO: write the entire implicit cast table out for test cases. + shouldCast(ByteType, IntegerType, IntegerType) + shouldCast(IntegerType, IntegerType, IntegerType) + shouldCast(IntegerType, LongType, LongType) + shouldCast(IntegerType, DecimalType, DecimalType.Unlimited) + shouldCast(LongType, IntegerType, IntegerType) + shouldCast(LongType, DecimalType, DecimalType.Unlimited) + + shouldCast(DateType, TimestampType, TimestampType) + shouldCast(TimestampType, DateType, DateType) + + shouldCast(StringType, IntegerType, IntegerType) + shouldCast(StringType, DateType, DateType) + shouldCast(StringType, TimestampType, TimestampType) + shouldCast(IntegerType, StringType, StringType) + shouldCast(DateType, StringType, StringType) + shouldCast(TimestampType, StringType, StringType) + + shouldCast(StringType, BinaryType, BinaryType) + shouldCast(BinaryType, StringType, StringType) + + shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType) + + shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType) + shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType) + shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType) + + shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType) + shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType) + shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) + shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) + + shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType) + shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType) + + shouldCast( + DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited) + shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) + shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) + shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) + } + + test("ineligible implicit type cast") { + def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { + val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got") + } + + shouldNotCast(IntegerType, DateType) + shouldNotCast(IntegerType, TimestampType) + shouldNotCast(LongType, DateType) + shouldNotCast(LongType, TimestampType) + shouldNotCast(DecimalType.Unlimited, DateType) + shouldNotCast(DecimalType.Unlimited, TimestampType) + + shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) + + shouldNotCast(IntegerType, ArrayType) + shouldNotCast(IntegerType, MapType) + shouldNotCast(IntegerType, StructType) + } + test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = HiveTypeCoercion.findTightestCommonType(t1, t2) + var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = HiveTypeCoercion.findTightestCommonType(t2, t1) + found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } @@ -104,31 +178,15 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } - test("boolean casts") { - val booleanCasts = new HiveTypeCoercion { }.BooleanCasts - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - // Remove superflous boolean -> boolean casts. - ruleTest(Cast(Literal(true), BooleanType), Literal(true)) - // Stringify boolean when casting to string. - ruleTest( - Cast(Literal(false), StringType), - If(Literal(false), Literal("true"), Literal("false"))) + private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + rule(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) } test("coalesce casts") { - val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - fac(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - ruleTest( + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -137,7 +195,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest( + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -147,4 +205,70 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) :: Nil)) } + + test("type coercion for If") { + val rule = HiveTypeCoercion.IfCoercion + ruleTest(rule, + If(Literal(true), Literal(1), Literal(1L)), + If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) + ) + + ruleTest(rule, + If(Literal.create(null, NullType), Literal(1), Literal(1)), + If(Literal.create(null, BooleanType), Literal(1), Literal(1)) + ) + } + + test("type coercion for CaseKeyWhen") { + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) + ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) + } + + test("type coercion simplification for equal to") { + val be = HiveTypeCoercion.BooleanEquality + + ruleTest(be, + EqualTo(Literal(true), Literal(1)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(true), Literal(0)), + Not(Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(1)), + And(IsNotNull(Literal(true)), Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(0)), + And(IsNotNull(Literal(true)), Not(Literal(true))) + ) + + ruleTest(be, + EqualTo(Literal(true), Literal(1L)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(BigDecimal(0)), Literal(true)), + Not(Literal(true)) + ) + ruleTest(be, + EqualTo(Literal(Decimal(1)), Literal(true)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)), + Literal(true) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala new file mode 100644 index 000000000000..6c93698f8017 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.Decimal + + +class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + /** + * Runs through the testFunc for all numeric data types. + * + * @param testFunc a test function that accepts a conversion function to convert an integer + * into another data type. + */ + private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toByte) + testFunc(_.toShort) + testFunc(identity) + testFunc(_.toLong) + testFunc(_.toFloat) + testFunc(_.toDouble) + testFunc(Decimal(_)) + } + + test("+ (Add)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Add(left, right), convert(3)) + checkEvaluation(Add(Literal.create(null, left.dataType), right), null) + checkEvaluation(Add(left, Literal.create(null, right.dataType)), null) + } + } + + test("- (UnaryMinus)") { + testNumericDataTypes { convert => + val input = Literal(convert(1)) + val dataType = input.dataType + checkEvaluation(UnaryMinus(input), convert(-1)) + checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) + } + } + + test("- (Minus)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Subtract(left, right), convert(-1)) + checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null) + checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null) + } + } + + test("* (Multiply)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Multiply(left, right), convert(2)) + checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null) + checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null) + } + } + + test("/ (Divide) basic") { + testNumericDataTypes { convert => + val left = Literal(convert(2)) + val right = Literal(convert(1)) + val dataType = left.dataType + checkEvaluation(Divide(left, right), convert(2)) + checkEvaluation(Divide(Literal.create(null, dataType), right), null) + checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero + } + } + + test("/ (Divide) for integral type") { + checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) + checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) + checkEvaluation(Divide(Literal(1), Literal(2)), 0) + checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) + } + + test("/ (Divide) for floating point") { + checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) + checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) + checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) + } + + test("% (Remainder)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Remainder(left, right), convert(1)) + checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null) + checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + } + + test("Abs") { + testNumericDataTypes { convert => + checkEvaluation(Abs(Literal(convert(0))), convert(0)) + checkEvaluation(Abs(Literal(convert(1))), convert(1)) + checkEvaluation(Abs(Literal(convert(-1))), convert(1)) + } + } + + test("MaxOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MaxOf(small, large), convert(2)) + checkEvaluation(MaxOf(large, small), convert(2)) + checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2)) + } + } + + test("MaxOf for atomic type") { + checkEvaluation(MaxOf(true, false), true) + checkEvaluation(MaxOf("abc", "bcd"), "bcd") + checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 3.toByte)) + } + + test("MinOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MinOf(small, large), convert(1)) + checkEvaluation(MinOf(large, small), convert(1)) + checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1)) + } + } + + test("MinOf for atomic type") { + checkEvaluation(MinOf(true, false), false) + checkEvaluation(MinOf("abc", "bcd"), "abc") + checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 2.toByte)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index f2f3a84d1938..97cfb5f06dd7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.IntegerType -class AttributeSetSuite extends FunSuite { +class AttributeSetSuite extends SparkFunSuite { val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala new file mode 100644 index 000000000000..c9bbc7a8b8c1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ + + +class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Bitwise operations") { + val row = create_row(1, 2, 3, null) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + + checkEvaluation(BitwiseAnd(c1, c4), null, row) + checkEvaluation(BitwiseAnd(c1, c2), 0, row) + checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseOr(c1, c4), null, row) + checkEvaluation(BitwiseOr(c1, c2), 3, row) + checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseXor(c1, c4), null, row) + checkEvaluation(BitwiseXor(c1, c2), 3, row) + checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) + checkEvaluation( + BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) + + checkEvaluation(BitwiseNot(c4), null, row) + checkEvaluation(BitwiseNot(c1), -2, row) + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) + + checkEvaluation(c1 & c2, 0, row) + checkEvaluation(c1 | c2, 3, row) + checkEvaluation(c1 ^ c2, 3, row) + checkEvaluation(~c1, -2, row) + } + + test("unary BitwiseNOT") { + checkEvaluation(BitwiseNot(1), -2) + assert(BitwiseNot(1).dataType === IntegerType) + assert(BitwiseNot(1).eval(EmptyRow).isInstanceOf[Int]) + + checkEvaluation(BitwiseNot(1.toLong), -2.toLong) + assert(BitwiseNot(1.toLong).dataType === LongType) + assert(BitwiseNot(1.toLong).eval(EmptyRow).isInstanceOf[Long]) + + checkEvaluation(BitwiseNot(1.toShort), -2.toShort) + assert(BitwiseNot(1.toShort).dataType === ShortType) + assert(BitwiseNot(1.toShort).eval(EmptyRow).isInstanceOf[Short]) + + checkEvaluation(BitwiseNot(1.toByte), -2.toByte) + assert(BitwiseNot(1.toByte).dataType === ByteType) + assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala new file mode 100644 index 000000000000..f3809be722a8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Timestamp, Date} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * Test suite for data type casting expression [[Cast]]. + */ +class CastSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def cast(v: Any, targetType: DataType): Cast = { + v match { + case lit: Expression => Cast(lit, targetType) + case _ => Cast(Literal(v), targetType) + } + } + + // expected cannot be null + private def checkCast(v: Any, expected: Any): Unit = { + checkEvaluation(cast(v, Literal(expected).dataType), expected) + } + + test("cast from int") { + checkCast(0, false) + checkCast(1, true) + checkCast(-5, true) + checkCast(1, 1.toByte) + checkCast(1, 1.toShort) + checkCast(1, 1) + checkCast(1, 1.toLong) + checkCast(1, 1.0f) + checkCast(1, 1.0) + checkCast(123, "123") + + checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 1)), null) + checkEvaluation(cast(123, DecimalType(2, 0)), null) + } + + test("cast from long") { + checkCast(0L, false) + checkCast(1L, true) + checkCast(-5L, true) + checkCast(1L, 1.toByte) + checkCast(1L, 1.toShort) + checkCast(1L, 1) + checkCast(1L, 1.toLong) + checkCast(1L, 1.0f) + checkCast(1L, 1.0) + checkCast(123L, "123") + + checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + + // TODO: Fix the following bug and re-enable it. + // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + } + + test("cast from boolean") { + checkEvaluation(cast(true, IntegerType), 1) + checkEvaluation(cast(false, IntegerType), 0) + checkEvaluation(cast(true, StringType), "true") + checkEvaluation(cast(false, StringType), "false") + checkEvaluation(cast(cast(1, BooleanType), IntegerType), 1) + checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0) + } + + test("cast from int 2") { + checkEvaluation(cast(1, LongType), 1.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + + checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) + checkEvaluation(cast(123, DecimalType(3, 1)), null) + checkEvaluation(cast(123, DecimalType(2, 0)), null) + } + + test("cast from float") { + checkCast(0.0f, false) + checkCast(0.5f, true) + checkCast(-5.0f, true) + checkCast(1.5f, 1.toByte) + checkCast(1.5f, 1.toShort) + checkCast(1.5f, 1) + checkCast(1.5f, 1.toLong) + checkCast(1.5f, 1.5) + checkCast(1.5f, "1.5") + } + + test("cast from double") { + checkCast(0.0, false) + checkCast(0.5, true) + checkCast(-5.0, true) + checkCast(1.5, 1.toByte) + checkCast(1.5, 1.toShort) + checkCast(1.5, 1) + checkCast(1.5, 1.toLong) + checkCast(1.5, 1.5f) + checkCast(1.5, "1.5") + + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) + } + + test("cast from string") { + assert(cast("abcdef", StringType).nullable === false) + assert(cast("abcdef", BinaryType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", TimestampType).nullable === true) + assert(cast("abcdef", LongType).nullable === true) + assert(cast("abcdef", IntegerType).nullable === true) + assert(cast("abcdef", ShortType).nullable === true) + assert(cast("abcdef", ByteType).nullable === true) + assert(cast("abcdef", DecimalType.Unlimited).nullable === true) + assert(cast("abcdef", DecimalType(4, 2)).nullable === true) + assert(cast("abcdef", DoubleType).nullable === true) + assert(cast("abcdef", FloatType).nullable === true) + } + + test("data type casting") { + val sd = "1970-01-01" + val d = Date.valueOf(sd) + val zts = sd + " 00:00:00" + val sts = sd + " 00:00:02" + val nts = sts + ".1" + val ts = Timestamp.valueOf(nts) + + checkEvaluation(cast("abdef", StringType), "abdef") + checkEvaluation(cast("abdef", DecimalType.Unlimited), null) + checkEvaluation(cast("abdef", TimestampType), null) + checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65)) + + checkEvaluation(cast(cast(sd, DateType), StringType), sd) + checkEvaluation(cast(cast(d, StringType), DateType), 0) + checkEvaluation(cast(cast(nts, TimestampType), StringType), nts) + checkEvaluation(cast(cast(ts, StringType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + + // all convert to string type to check + checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd) + checkEvaluation(cast(cast(cast(ts, DateType), TimestampType), StringType), zts) + + checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef") + + checkEvaluation(cast(cast(cast(cast( + cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType), + 5.toLong) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), + DecimalType.Unlimited), LongType), StringType), ShortType), + 0.toShort) + checkEvaluation( + cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), + DecimalType.Unlimited), LongType), StringType), ShortType), + null) + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited), + ByteType), TimestampType), LongType), StringType), ShortType), + 0.toShort) + + checkEvaluation(cast("23", DoubleType), 23d) + checkEvaluation(cast("23", IntegerType), 23) + checkEvaluation(cast("23", FloatType), 23f) + checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23)) + checkEvaluation(cast("23", ByteType), 23.toByte) + checkEvaluation(cast("23", ShortType), 23.toShort) + checkEvaluation(cast("2012-12-11", DoubleType), null) + checkEvaluation(cast(123, IntegerType), 123) + + + checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null) + } + + test("cast and add") { + checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d) + checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24) + checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f) + checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24)) + checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte) + checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) + } + + test("from decimal") { + checkCast(Decimal(0.0), false) + checkCast(Decimal(0.5), true) + checkCast(Decimal(-5.0), true) + checkCast(Decimal(1.5), 1.toByte) + checkCast(Decimal(1.5), 1.toShort) + checkCast(Decimal(1.5), 1) + checkCast(Decimal(1.5), 1.toLong) + checkCast(Decimal(1.5), 1.5f) + checkCast(Decimal(1.5), 1.5) + checkCast(Decimal(1.5), "1.5") + } + + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(cast(123, DecimalType.Unlimited).nullable === false) + assert(cast(10.03f, DecimalType.Unlimited).nullable === true) + assert(cast(10.03, DecimalType.Unlimited).nullable === true) + assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false) + + assert(cast(123, DecimalType(2, 1)).nullable === true) + assert(cast(10.03f, DecimalType(2, 1)).nullable === true) + assert(cast(10.03, DecimalType(2, 1)).nullable === true) + assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) + + + checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.03, DecimalType(1, 0)), null) + checkEvaluation(cast(10.03, DecimalType(2, 1)), null) + checkEvaluation(cast(10.03, DecimalType(3, 2)), null) + checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null) + + checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(10.05, DecimalType(1, 0)), null) + checkEvaluation(cast(10.05, DecimalType(2, 1)), null) + checkEvaluation(cast(10.05, DecimalType(3, 2)), null) + checkEvaluation(cast(Decimal(10.05), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(cast(Decimal(10.05), DecimalType(3, 2)), null) + + checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10)) + checkEvaluation(cast(9.95, DecimalType(2, 1)), null) + checkEvaluation(cast(9.95, DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(cast(Decimal(9.95), DecimalType(1, 0)), null) + + checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(cast(-9.95, DecimalType(2, 1)), null) + checkEvaluation(cast(-9.95, DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) + + checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null) + checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null) + + checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) + checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null) + } + + test("cast from date") { + val d = Date.valueOf("1970-01-01") + checkEvaluation(cast(d, ShortType), null) + checkEvaluation(cast(d, IntegerType), null) + checkEvaluation(cast(d, LongType), null) + checkEvaluation(cast(d, FloatType), null) + checkEvaluation(cast(d, DoubleType), null) + checkEvaluation(cast(d, DecimalType.Unlimited), null) + checkEvaluation(cast(d, DecimalType(10, 2)), null) + checkEvaluation(cast(d, StringType), "1970-01-01") + checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + } + + test("cast from timestamp") { + val millis = 15 * 1000 + 2 + val seconds = millis * 1000 + 2 + val ts = new Timestamp(millis) + val tss = new Timestamp(seconds) + checkEvaluation(cast(ts, ShortType), 15.toShort) + checkEvaluation(cast(ts, IntegerType), 15) + checkEvaluation(cast(ts, LongType), 15.toLong) + checkEvaluation(cast(ts, FloatType), 15.002f) + checkEvaluation(cast(ts, DoubleType), 15.002) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation( + cast(cast(millis.toFloat / 1000, TimestampType), FloatType), + millis.toFloat / 1000) + checkEvaluation( + cast(cast(millis.toDouble / 1000, TimestampType), DoubleType), + millis.toDouble / 1000) + checkEvaluation( + cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited), + Decimal(1)) + + // A test for higher precision than millis + checkEvaluation(cast(cast(0.0000001, TimestampType), DoubleType), 0.0000001) + + checkEvaluation(cast(Double.NaN, TimestampType), null) + checkEvaluation(cast(1.0 / 0.0, TimestampType), null) + checkEvaluation(cast(Float.NaN, TimestampType), null) + checkEvaluation(cast(1.0f / 0.0f, TimestampType), null) + } + + test("cast from array") { + val array = Literal.create(Seq("123", "abc", "", null), + ArrayType(StringType, containsNull = true)) + val array_notNull = Literal.create(Seq("123", "abc", ""), + ArrayType(StringType, containsNull = false)) + + { + val ret = cast(array, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(123, null, null, null)) + } + { + val ret = cast(array, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(array, ArrayType(BooleanType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false, null)) + } + { + val ret = cast(array, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved === false) + } + + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(123, null, null)) + } + { + val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false)) + } + { + val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) + assert(ret.resolved === true) + checkEvaluation(ret, Seq(true, true, false)) + } + + { + val ret = cast(array, IntegerType) + assert(ret.resolved === false) + } + } + + test("cast from map") { + val map = Literal.create( + Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val map_notNull = Literal.create( + Map("a" -> "123", "b" -> "abc", "c" -> ""), + MapType(StringType, StringType, valueContainsNull = false)) + + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) + } + { + val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + } + { + val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } + + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null)) + } + { + val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) + assert(ret.resolved === false) + } + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) + assert(ret.resolved === true) + checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + } + { + val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) + assert(ret.resolved === false) + } + + { + val ret = cast(map, IntegerType) + assert(ret.resolved === false) + } + } + + test("cast from struct") { + val struct = Literal.create( + InternalRow("123", "abc", "", null), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true), + StructField("d", StringType, nullable = true)))) + val struct_notNull = Literal.create( + InternalRow("123", "abc", ""), + StructType(Seq( + StructField("a", StringType, nullable = false), + StructField("b", StringType, nullable = false), + StructField("c", StringType, nullable = false)))) + + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, InternalRow(123, null, null, null)) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = true)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true), + StructField("d", BooleanType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, InternalRow(true, true, false, null)) + } + { + val ret = cast(struct, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false), + StructField("d", BooleanType, nullable = true)))) + assert(ret.resolved === false) + } + + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, InternalRow(123, null, null)) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = true)))) + assert(ret.resolved === true) + checkEvaluation(ret, InternalRow(true, true, false)) + } + { + val ret = cast(struct_notNull, StructType(Seq( + StructField("a", BooleanType, nullable = true), + StructField("b", BooleanType, nullable = true), + StructField("c", BooleanType, nullable = false)))) + assert(ret.resolved === true) + checkEvaluation(ret, InternalRow(true, true, false)) + } + + { + val ret = cast(struct, StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", StringType, nullable = true)))) + assert(ret.resolved === false) + } + { + val ret = cast(struct, IntegerType) + assert(ret.resolved === false) + } + } + + test("complex casting") { + val complex = Literal.create( + InternalRow( + Seq("123", "abc", ""), + Map("a" -> "123", "b" -> "abc", "c" -> ""), + InternalRow(0)), + StructType(Seq( + StructField("a", + ArrayType(StringType, containsNull = false), nullable = true), + StructField("m", + MapType(StringType, StringType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("i", IntegerType, nullable = true))))))) + + val ret = cast(complex, StructType(Seq( + StructField("a", + ArrayType(IntegerType, containsNull = true), nullable = true), + StructField("m", + MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), + StructField("s", + StructType(Seq( + StructField("l", LongType, nullable = true))))))) + + assert(ret.resolved === true) + checkEvaluation(ret, InternalRow( + Seq(123, null, null), + Map("a" -> true, "b" -> true, "c" -> false), + InternalRow(0L))) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala similarity index 62% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index b5ebe4b38e33..481b335d15df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,37 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ /** - * Overrides our expression evaluation tests to use code generation for evaluation. + * Additional tests for code generation. */ -class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val evaluated = GenerateProjection.expressionEvaluator(expression) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} - |$e - """.stripMargin) - } - - val actual = plan(inputRow).apply(0) - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - +class CodeGenerationSuite extends SparkFunSuite { test("multithreaded eval") { import scala.concurrent._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala new file mode 100644 index 000000000000..a09014e1ffc1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { + + /** + * Runs through the testFunc for all integral data types. + * + * @param testFunc a test function that accepts a conversion function to convert an integer + * into another data type. + */ + private def testIntegralDataTypes(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toByte) + testFunc(_.toShort) + testFunc(identity) + testFunc(_.toLong) + } + + test("GetArrayItem") { + val typeA = ArrayType(StringType) + val array = Literal.create(Seq("a", "b"), typeA) + testIntegralDataTypes { convert => + checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") + } + val nullArray = Literal.create(null, typeA) + val nullInt = Literal.create(null, IntegerType) + checkEvaluation(GetArrayItem(nullArray, Literal(1)), null) + checkEvaluation(GetArrayItem(array, nullInt), null) + checkEvaluation(GetArrayItem(nullArray, nullInt), null) + + val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) + } + + test("GetMapValue") { + val typeM = MapType(StringType, StringType) + val map = Literal.create(Map("a" -> "b"), typeM) + val nullMap = Literal.create(null, typeM) + val nullString = Literal.create(null, StringType) + + checkEvaluation(GetMapValue(map, Literal("a")), "b") + checkEvaluation(GetMapValue(map, nullString), null) + checkEvaluation(GetMapValue(nullMap, nullString), null) + checkEvaluation(GetMapValue(map, nullString), null) + + val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM)) + checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) + } + + test("GetStructField") { + val typeS = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), typeS) + val nullStruct = Literal.create(null, typeS) + + def getStructField(expr: Expression, fieldName: String): GetStructField = { + expr.dataType match { + case StructType(fields) => + val field = fields.find(_.name == fieldName).get + GetStructField(expr, field, fields.indexOf(field)) + } + } + + checkEvaluation(getStructField(struct, "a"), 1) + checkEvaluation(getStructField(nullStruct, "a"), null) + + val nestedStruct = Literal.create(create_row(create_row(1)), + StructType(StructField("a", typeS) :: Nil)) + checkEvaluation(getStructField(nestedStruct, "a"), create_row(1)) + + val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil) + val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable) + val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable) + + assert(getStructField(struct_fieldNotNullable, "a").nullable === false) + assert(getStructField(struct, "a").nullable === true) + assert(getStructField(nullStruct_fieldNotNullable, "a").nullable === true) + assert(getStructField(nullStruct, "a").nullable === true) + } + + test("GetArrayStructFields") { + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) + val nullArrayStruct = Literal.create(null, typeAS) + + def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { + expr.dataType match { + case ArrayType(StructType(fields), containsNull) => + val field = fields.find(_.name == fieldName).get + GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + } + } + + checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1)) + checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) + } + + test("CreateStruct") { + val row = create_row(1, 2, 3) + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) + checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + } + + test("CreateNamedStruct") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + val c3 = 'c.int.at(2) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) + } + + test("CreateNamedStruct with literal field") { + val row = InternalRow(1, 2, 3) + val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + } + + test("CreateNamedStruct from all literal fields") { + checkEvaluation( + CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + } + + test("test dsl for complex type") { + def quickResolve(u: UnresolvedExtractValue): Expression = { + ExtractValue(u.child, u.extraction, _ == _) + } + + checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")), + "b", create_row(Map("a" -> "b"))) + checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), + "b", create_row(Seq("a", "b"))) + checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + 1, create_row(create_row(1))) + } + + test("error message of ExtractValue") { + val structType = StructType(StructField("a", StringType, true) :: Nil) + val arrayStructType = ArrayType(structType) + val arrayType = ArrayType(StringType) + val otherType = StringType + + def checkErrorMessage( + childDataType: DataType, + fieldDataType: DataType, + errorMesage: String): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + ExtractValue( + Literal.create(null, childDataType), + Literal.create(null, fieldDataType), + _ == _) + } + assert(e.getMessage().contains(errorMesage)) + } + + checkErrorMessage(structType, IntegerType, "Field name should be String Literal") + checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal") + checkErrorMessage(arrayType, StringType, "Array index should be integral type") + checkErrorMessage(otherType, StringType, "Can't extract value from") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala new file mode 100644 index 000000000000..372848ea9a59 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ + + +class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("if") { + val testcases = Seq[(java.lang.Boolean, Integer, Integer, Integer)]( + (true, 1, 2, 1), + (false, 1, 2, 2), + (null, 1, 2, 2), + (true, null, 2, null), + (false, 1, null, null), + (null, null, 2, 2), + (null, 1, null, null) + ) + + // dataType must match T. + def testIf(convert: (Integer => Any), dataType: DataType): Unit = { + for ((predicate, trueValue, falseValue, expected) <- testcases) { + val trueValueConverted = if (trueValue == null) null else convert(trueValue) + val falseValueConverted = if (falseValue == null) null else convert(falseValue) + val expectedConverted = if (expected == null) null else convert(expected) + + checkEvaluation( + If(Literal.create(predicate, BooleanType), + Literal.create(trueValueConverted, dataType), + Literal.create(falseValueConverted, dataType)), + expectedConverted) + } + } + + testIf(_ == 1, BooleanType) + testIf(_.toShort, ShortType) + testIf(identity, IntegerType) + testIf(_.toLong, LongType) + + testIf(_.toFloat, FloatType) + testIf(_.toDouble, DoubleType) + testIf(Decimal(_), DecimalType.Unlimited) + + testIf(identity, DateType) + testIf(_.toLong, TimestampType) + + testIf(_.toString, StringType) + } + + test("case when") { + val row = create_row(null, false, true, "a", "b", "c") + val c1 = 'a.boolean.at(0) + val c2 = 'a.boolean.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) + + checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + + assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + + val c4_notNull = 'a.boolean.notNull.at(3) + val c5_notNull = 'a.boolean.notNull.at(4) + val c6_notNull = 'a.boolean.notNull.at(5) + + assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + } + + test("case key when") { + val row = create_row(null, 1, 2, "a", "b", "c") + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + val literalNull = Literal.create(null, IntegerType) + val literalInt = Literal(1) + val literalString = Literal("a") + + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) + checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) + checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row) + + checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) + checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala new file mode 100644 index 000000000000..1618c24871c6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class DatetimeFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + test("datetime function current_date") { + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) + } + + test("datetime function current_timestamp") { + val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) + val t1 = System.currentTimeMillis() + assert(math.abs(t1 - ct.getTime) < 5000) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala new file mode 100644 index 000000000000..3171caf6ad77 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalactic.TripleEqualsSupport.Spread +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} + +/** + * A few helper functions for expression evaluation testing. Mixin this trait to use them. + */ +trait ExpressionEvalHelper { + self: SparkFunSuite => + + protected def create_row(values: Any*): InternalRow = { + InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) + } + + protected def checkEvaluation( + expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) + checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) + checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) + checkEvaluationWithOptimization(expression, catalystValue, inputRow) + } + + /** + * Check the equality between result of expression and expected value, it will handle + * Array[Byte]. + */ + protected def checkResult(result: Any, expected: Any): Boolean = { + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case _ => result == expected + } + } + + protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + expression.eval(inputRow) + } + + protected def checkEvaluationWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (!checkResult(actual, expected)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: $expected$input") + } + } + + protected def checkEvaluationWithGeneratedMutableProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (!checkResult(actual, expected)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + protected def checkEvaluationWithGeneratedProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = InternalRow(expected) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |Expressions: $expression + |Code: $evaluated + """.stripMargin) + } + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + if (actual.copy() != expectedRow) { + fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") + } + } + + protected def checkEvaluationWithOptimization( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = DefaultOptimizer.execute(plan) + checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) + } + + protected def checkDoubleEvaluation( + expression: Expression, + expected: Spread[Double], + inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala deleted file mode 100644 index 5c4a1527c27c..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ /dev/null @@ -1,1371 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import java.sql.{Date, Timestamp} - -import scala.collection.immutable.HashSet - -import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.FunSuite -import org.scalatest.Matchers._ - -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.mathfuncs._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types._ - - -class ExpressionEvaluationBaseSuite extends FunSuite { - - def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.eval(inputRow) - } - - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - - def checkDoubleEvaluation( - expression: Expression, - expected: Spread[Double], - inputRow: Row = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - actual.asInstanceOf[Double] shouldBe expected - } -} - -class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { - - def create_row(values: Any*): Row = { - new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) - } - - test("literals") { - checkEvaluation(Literal(1), 1) - checkEvaluation(Literal(true), true) - checkEvaluation(Literal(0L), 0L) - checkEvaluation(Literal("test"), "test") - checkEvaluation(Literal(1) + Literal(1), 2) - } - - test("unary BitwiseNOT") { - checkEvaluation(BitwiseNot(1), -2) - assert(BitwiseNot(1).dataType === IntegerType) - assert(BitwiseNot(1).eval(EmptyRow).isInstanceOf[Int]) - checkEvaluation(BitwiseNot(1.toLong), -2.toLong) - assert(BitwiseNot(1.toLong).dataType === LongType) - assert(BitwiseNot(1.toLong).eval(EmptyRow).isInstanceOf[Long]) - checkEvaluation(BitwiseNot(1.toShort), -2.toShort) - assert(BitwiseNot(1.toShort).dataType === ShortType) - assert(BitwiseNot(1.toShort).eval(EmptyRow).isInstanceOf[Short]) - checkEvaluation(BitwiseNot(1.toByte), -2.toByte) - assert(BitwiseNot(1.toByte).dataType === ByteType) - assert(BitwiseNot(1.toByte).eval(EmptyRow).isInstanceOf[Byte]) - } - - // scalastyle:off - /** - * Checks for three-valued-logic. Based on: - * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * I.e. in flat cpo "False -> Unknown -> True", - * OR is lowest upper bound, - * AND is greatest lower bound. - * p q p OR q p AND q p = q - * True True True True True - * True False True False False - * True Unknown True Unknown Unknown - * False True True False False - * False False False False True - * False Unknown Unknown False Unknown - * Unknown True True Unknown Unknown - * Unknown False Unknown False Unknown - * Unknown Unknown Unknown Unknown Unknown - * - * p NOT p - * True False - * False True - * Unknown Unknown - */ - // scalastyle:on - val notTrueTable = - (true, false) :: - (false, true) :: - (null, null) :: Nil - - test("3VL Not") { - notTrueTable.foreach { - case (v, answer) => - checkEvaluation(!Literal.create(v, BooleanType), answer) - } - } - - booleanLogicTest("AND", _ && _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: - (false, false, false) :: - (false, null, false) :: - (null, true, null) :: - (null, false, false) :: - (null, null, null) :: Nil) - - booleanLogicTest("OR", _ || _, - (true, true, true) :: - (true, false, true) :: - (true, null, true) :: - (false, true, true) :: - (false, false, false) :: - (false, null, null) :: - (null, true, true) :: - (null, false, null) :: - (null, null, null) :: Nil) - - booleanLogicTest("=", _ === _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: - (false, false, true) :: - (false, null, null) :: - (null, true, null) :: - (null, false, null) :: - (null, null, null) :: Nil) - - def booleanLogicTest( - name: String, - op: (Expression, Expression) => Expression, - truthTable: Seq[(Any, Any, Any)]) { - test(s"3VL $name") { - truthTable.foreach { - case (l,r,answer) => - val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) - checkEvaluation(expr, answer) - } - } - } - - test("IN") { - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) - checkEvaluation( - In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), - true) - } - - test("Divide") { - checkEvaluation(Divide(Literal(2), Literal(1)), 2) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1), Literal(0)), null) - checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("Remainder") { - checkEvaluation(Remainder(Literal(2), Literal(1)), 0) - checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) - checkEvaluation(Remainder(Literal(1), Literal(2)), 1) - checkEvaluation(Remainder(Literal(1), Literal(0)), null) - checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("INSET") { - val hS = HashSet[Any]() + 1 + 2 - val nS = HashSet[Any]() + 1 + 2 + null - val one = Literal(1) - val two = Literal(2) - val three = Literal(3) - val nl = Literal(null) - val s = Seq(one, two) - val nullS = Seq(one, two, null) - checkEvaluation(InSet(one, hS), true) - checkEvaluation(InSet(two, hS), true) - checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(nl, nS), true) - checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), false) - checkEvaluation(InSet(one, hS) && InSet(two, hS), true) - } - - test("MaxOf") { - checkEvaluation(MaxOf(1, 2), 2) - checkEvaluation(MaxOf(2, 1), 2) - checkEvaluation(MaxOf(1L, 2L), 2L) - checkEvaluation(MaxOf(2L, 1L), 2L) - - checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) - } - - test("MinOf") { - checkEvaluation(MinOf(1, 2), 1) - checkEvaluation(MinOf(2, 1), 1) - checkEvaluation(MinOf(1L, 2L), 1L) - checkEvaluation(MinOf(2L, 1L), 1L) - - checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) - checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) - } - - test("LIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType).like("a"), null) - checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) - checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) - checkEvaluation("abdef" like "abdef", true) - checkEvaluation("a_%b" like "a\\__b", true) - checkEvaluation("addb" like "a_%b", true) - checkEvaluation("addb" like "a\\__b", false) - checkEvaluation("addb" like "a%\\%b", false) - checkEvaluation("a_%b" like "a%\\%b", true) - checkEvaluation("addb" like "a%", true) - checkEvaluation("addb" like "**", false) - checkEvaluation("abc" like "a%", true) - checkEvaluation("abc" like "b%", false) - checkEvaluation("abc" like "bc%", false) - checkEvaluation("a\nb" like "a_b", true) - checkEvaluation("ab" like "a%b", true) - checkEvaluation("a\nb" like "a%b", true) - } - - test("LIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(null)) - checkEvaluation("abdef" like regEx, true, create_row("abdef")) - checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) - checkEvaluation("addb" like regEx, true, create_row("a_%b")) - checkEvaluation("addb" like regEx, false, create_row("a\\__b")) - checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) - checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) - checkEvaluation("addb" like regEx, true, create_row("a%")) - checkEvaluation("addb" like regEx, false, create_row("**")) - checkEvaluation("abc" like regEx, true, create_row("a%")) - checkEvaluation("abc" like regEx, false, create_row("b%")) - checkEvaluation("abc" like regEx, false, create_row("bc%")) - checkEvaluation("a\nb" like regEx, true, create_row("a_b")) - checkEvaluation("ab" like regEx, true, create_row("a%b")) - checkEvaluation("a\nb" like regEx, true, create_row("a%b")) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) - } - - test("RLIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal.create(null, StringType), null) - checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) - checkEvaluation("abdef" rlike "abdef", true) - checkEvaluation("abbbbc" rlike "a.*c", true) - - checkEvaluation("fofo" rlike "^fo", true) - checkEvaluation("fo\no" rlike "^fo\no$", true) - checkEvaluation("Bn" rlike "^Ba*n", true) - checkEvaluation("afofo" rlike "fo", true) - checkEvaluation("afofo" rlike "^fo", false) - checkEvaluation("Baan" rlike "^Ba?n", false) - checkEvaluation("axe" rlike "pi|apa", false) - checkEvaluation("pip" rlike "^(pi)*$", false) - - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") - } - } - - test("RLIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) - checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) - checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) - checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) - checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row("**")) - } - } - - test("data type casting") { - - val sd = "1970-01-01" - val d = Date.valueOf(sd) - val zts = sd + " 00:00:00" - val sts = sd + " 00:00:02" - val nts = sts + ".1" - val ts = Timestamp.valueOf(nts) - - checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType.Unlimited, null) - checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) - - checkEvaluation(Literal(1) cast LongType, 1) - checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) - checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) - checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - - checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) - checkEvaluation(Cast(Literal(d) cast StringType, DateType), 0) - checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) - checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) - // all convert to string type to check - checkEvaluation( - Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) - checkEvaluation( - Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), zts) - - checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") - - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast - DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) - checkEvaluation(Literal(true) cast IntegerType, 1) - checkEvaluation(Literal(false) cast IntegerType, 0) - checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) - checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23d) - checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) - checkEvaluation("23" cast ByteType, 23.toByte) - checkEvaluation("23" cast ShortType, 23.toShort) - checkEvaluation("2012-12-11" cast DoubleType, null) - checkEvaluation(Literal(123) cast IntegerType, 123) - - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) - checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) - - intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} - - assert(("abcdef" cast StringType).nullable === false) - assert(("abcdef" cast BinaryType).nullable === false) - assert(("abcdef" cast BooleanType).nullable === false) - assert(("abcdef" cast TimestampType).nullable === true) - assert(("abcdef" cast LongType).nullable === true) - assert(("abcdef" cast IntegerType).nullable === true) - assert(("abcdef" cast ShortType).nullable === true) - assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType.Unlimited).nullable === true) - assert(("abcdef" cast DecimalType(4, 2)).nullable === true) - assert(("abcdef" cast DoubleType).nullable === true) - assert(("abcdef" cast FloatType).nullable === true) - - checkEvaluation(Cast(Literal.create(null, IntegerType), ShortType), null) - } - - test("date") { - val d1 = DateUtils.fromJavaDate(Date.valueOf("1970-01-01")) - val d2 = DateUtils.fromJavaDate(Date.valueOf("1970-01-02")) - checkEvaluation(Literal(d1) < Literal(d2), true) - } - - test("casting to fixed-precision decimals") { - // Overflow and rounding for casting to fixed-precision decimals: - // - Values should round with HALF_UP mode by default when you lower scale - // - Values that would overflow the target precision should turn into null - // - Because of this, casts to fixed-precision decimals should be nullable - - assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) - assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === true) - assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) - - assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) - assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) - - checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) - checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) - checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) - checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) - - checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) - checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) - checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) - checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) - - checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) - checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) - checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) - checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) - checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) - checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) - - checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) - checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) - checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) - checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) - - checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) - checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) - checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) - checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null) - - checkEvaluation(Cast(Literal(Double.NaN), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(Float.NaN), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), DecimalType.Unlimited), null) - - checkEvaluation(Cast(Literal(Double.NaN), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(Float.NaN), DecimalType(2, 1)), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), DecimalType(2, 1)), null) - } - - test("timestamp") { - val ts1 = new Timestamp(12) - val ts2 = new Timestamp(123) - checkEvaluation(Literal("ab") < Literal("abc"), true) - checkEvaluation(Literal(ts1) < Literal(ts2), true) - } - - test("date casting") { - val d = Date.valueOf("1970-01-01") - checkEvaluation(Cast(Literal(d), ShortType), null) - checkEvaluation(Cast(Literal(d), IntegerType), null) - checkEvaluation(Cast(Literal(d), LongType), null) - checkEvaluation(Cast(Literal(d), FloatType), null) - checkEvaluation(Cast(Literal(d), DoubleType), null) - checkEvaluation(Cast(Literal(d), DecimalType.Unlimited), null) - checkEvaluation(Cast(Literal(d), DecimalType(10, 2)), null) - checkEvaluation(Cast(Literal(d), StringType), "1970-01-01") - checkEvaluation(Cast(Cast(Literal(d), TimestampType), StringType), "1970-01-01 00:00:00") - } - - test("timestamp casting") { - val millis = 15 * 1000 + 2 - val seconds = millis * 1000 + 2 - val ts = new Timestamp(millis) - val tss = new Timestamp(seconds) - checkEvaluation(Cast(ts, ShortType), 15) - checkEvaluation(Cast(ts, IntegerType), 15) - checkEvaluation(Cast(ts, LongType), 15) - checkEvaluation(Cast(ts, FloatType), 15.002f) - checkEvaluation(Cast(ts, DoubleType), 15.002) - checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) - checkEvaluation(Cast(Cast(tss, IntegerType), TimestampType), ts) - checkEvaluation(Cast(Cast(tss, LongType), TimestampType), ts) - checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType), - millis.toFloat / 1000) - checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), - millis.toDouble / 1000) - checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) - - // A test for higher precision than millis - checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) - - checkEvaluation(Cast(Literal(Double.NaN), TimestampType), null) - checkEvaluation(Cast(Literal(1.0 / 0.0), TimestampType), null) - checkEvaluation(Cast(Literal(Float.NaN), TimestampType), null) - checkEvaluation(Cast(Literal(1.0f / 0.0f), TimestampType), null) - } - - test("array casting") { - val array = Literal.create(Seq("123", "abc", "", null), - ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), - ArrayType(StringType, containsNull = false)) - - { - val cast = Cast(array, ArrayType(IntegerType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(123, null, null, null)) - } - { - val cast = Cast(array, ArrayType(IntegerType, containsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(array, ArrayType(BooleanType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false, null)) - } - { - val cast = Cast(array, ArrayType(BooleanType, containsNull = false)) - assert(cast.resolved === false) - } - - { - val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(123, null, null)) - } - { - val cast = Cast(array_notNull, ArrayType(IntegerType, containsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false)) - } - { - val cast = Cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(cast.resolved === true) - checkEvaluation(cast, Seq(true, true, false)) - } - - { - val cast = Cast(array, IntegerType) - assert(cast.resolved === false) - } - } - - test("map casting") { - val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), - MapType(StringType, StringType, valueContainsNull = true)) - val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), - MapType(StringType, StringType, valueContainsNull = false)) - - { - val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null)) - } - { - val cast = Cast(map, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) - } - { - val cast = Cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(cast.resolved === false) - } - - { - val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> 123, "b" -> null, "c" -> null)) - } - { - val cast = Cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false)) - assert(cast.resolved === false) - } - { - val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) - } - { - val cast = Cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(cast.resolved === true) - checkEvaluation(cast, Map("a" -> true, "b" -> true, "c" -> false)) - } - { - val cast = Cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) - assert(cast.resolved === false) - } - - { - val cast = Cast(map, IntegerType) - assert(cast.resolved === false) - } - } - - test("struct casting") { - val struct = Literal.create( - Row("123", "abc", "", null), - StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", StringType, nullable = true), - StructField("c", StringType, nullable = true), - StructField("d", StringType, nullable = true)))) - val struct_notNull = Literal.create( - Row("123", "abc", ""), - StructType(Seq( - StructField("a", StringType, nullable = false), - StructField("b", StringType, nullable = false), - StructField("c", StringType, nullable = false)))) - - { - val cast = Cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true), - StructField("d", IntegerType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(123, null, null, null)) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false), - StructField("d", IntegerType, nullable = true)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = true), - StructField("d", BooleanType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false, null)) - } - { - val cast = Cast(struct, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = false), - StructField("d", BooleanType, nullable = true)))) - assert(cast.resolved === false) - } - - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(123, null, null)) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", IntegerType, nullable = true), - StructField("b", IntegerType, nullable = true), - StructField("c", IntegerType, nullable = false)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = true)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false)) - } - { - val cast = Cast(struct_notNull, StructType(Seq( - StructField("a", BooleanType, nullable = true), - StructField("b", BooleanType, nullable = true), - StructField("c", BooleanType, nullable = false)))) - assert(cast.resolved === true) - checkEvaluation(cast, Row(true, true, false)) - } - - { - val cast = Cast(struct, StructType(Seq( - StructField("a", StringType, nullable = true), - StructField("b", StringType, nullable = true), - StructField("c", StringType, nullable = true)))) - assert(cast.resolved === false) - } - { - val cast = Cast(struct, IntegerType) - assert(cast.resolved === false) - } - } - - test("complex casting") { - val complex = Literal.create( - Row( - Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), - Row(0)), - StructType(Seq( - StructField("a", - ArrayType(StringType, containsNull = false), nullable = true), - StructField("m", - MapType(StringType, StringType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("i", IntegerType, nullable = true))))))) - - val cast = Cast(complex, StructType(Seq( - StructField("a", - ArrayType(IntegerType, containsNull = true), nullable = true), - StructField("m", - MapType(StringType, BooleanType, valueContainsNull = false), nullable = true), - StructField("s", - StructType(Seq( - StructField("l", LongType, nullable = true))))))) - - assert(cast.resolved === true) - checkEvaluation(cast, Row( - Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), - Row(0L))) - } - - test("null checking") { - val row = create_row("^Ba*n", null, true, null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.boolean.at(3) - - checkEvaluation(c1.isNull, false, row) - checkEvaluation(c1.isNotNull, true, row) - - checkEvaluation(c2.isNull, true, row) - checkEvaluation(c2.isNotNull, false, row) - - checkEvaluation(Literal.create(1, ShortType).isNull, false) - checkEvaluation(Literal.create(1, ShortType).isNotNull, true) - - checkEvaluation(Literal.create(null, ShortType).isNull, true) - checkEvaluation(Literal.create(null, ShortType).isNotNull, false) - - checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) - checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) - - checkEvaluation( - If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) - checkEvaluation(If(c3, c1, c2), "^Ba*n", row) - checkEvaluation(If(c4, c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) - checkEvaluation(If(Literal.create(false, BooleanType), - Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) - - checkEvaluation(c1 in (c1, c2), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) - checkEvaluation( - Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) - } - - test("case when") { - val row = create_row(null, false, true, "a", "b", "c") - val c1 = 'a.boolean.at(0) - val c2 = 'a.boolean.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) - - checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row) - - checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) - checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) - checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) - - assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) - - val c4_notNull = 'a.boolean.notNull.at(3) - val c5_notNull = 'a.boolean.notNull.at(4) - val c6_notNull = 'a.boolean.notNull.at(5) - - assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) - - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) - - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) - assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) - } - - test("case key when") { - val row = create_row(null, 1, 2, "a", "b", "c") - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) - - val literalNull = Literal.create(null, BooleanType) - val literalInt = Literal(1) - val literalString = Literal("a") - - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) - checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) - checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) - checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row) - - checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) - checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) - checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row) - } - - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val typeS = StructType( - StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - - def getStructField(expr: Expression, fieldName: String): ExtractValue = { - expr.dataType match { - case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) - } - } - - def quickResolve(u: UnresolvedExtractValue): ExtractValue = { - ExtractValue(u.child, u.extraction, _ == _) - } - - checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) - - val typeS_notNullable = StructType( - StructField("a", StringType, nullable = false) - :: StructField("b", StringType, nullable = false) :: Nil - ) - - assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true) - assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable - === false) - - assert(getStructField(Literal.create(null, typeS), "a").nullable === true) - assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) - - checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) - checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) - checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) - } - - test("error message of ExtractValue") { - val structType = StructType(StructField("a", StringType, true) :: Nil) - val arrayStructType = ArrayType(structType) - val arrayType = ArrayType(StringType) - val otherType = StringType - - def checkErrorMessage( - childDataType: DataType, - fieldDataType: DataType, - errorMesage: String): Unit = { - val e = intercept[org.apache.spark.sql.AnalysisException] { - ExtractValue( - Literal.create(null, childDataType), - Literal.create(null, fieldDataType), - _ == _) - } - assert(e.getMessage().contains(errorMesage)) - } - - checkErrorMessage(structType, IntegerType, "Field name should be String Literal") - checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal") - checkErrorMessage(arrayType, StringType, "Array index should be integral type") - checkErrorMessage(otherType, StringType, "Can't extract value from") - } - - test("arithmetic") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) - - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(-c1, -1, row) - checkEvaluation(c1 + c2, 3, row) - checkEvaluation(c1 - c2, -1, row) - checkEvaluation(c1 * c2, 2, row) - checkEvaluation(c1 / c2, 0, row) - checkEvaluation(c1 % c2, 1, row) - } - - test("fractional arithmetic") { - val row = create_row(1.1, 2.0, 3.1, null) - val c1 = 'a.double.at(0) - val c2 = 'a.double.at(1) - val c3 = 'a.double.at(2) - val c4 = 'a.double.at(3) - - checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) - checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) - - checkEvaluation(-c1, -1.1, row) - checkEvaluation(c1 + c2, 3.1, row) - checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) - checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) - checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) - checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) - } - - test("BinaryComparison") { - val row = create_row(1, 2, 3, null, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - val c5 = 'a.int.at(4) - val c6 = 'a.int.at(5) - - checkEvaluation(LessThan(c1, c4), null, row) - checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 < c2, true, row) - checkEvaluation(c1 <= c2, true, row) - checkEvaluation(c1 > c2, false, row) - checkEvaluation(c1 >= c2, false, row) - checkEvaluation(c1 === c2, false, row) - checkEvaluation(c1 !== c2, true, row) - checkEvaluation(c4 <=> c1, false, row) - checkEvaluation(c1 <=> c4, false, row) - checkEvaluation(c4 <=> c6, true, row) - checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) - checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) - } - - test("StringComparison") { - val row = create_row("abc", null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) - - checkEvaluation(c1 contains "b", true, row) - checkEvaluation(c1 contains "x", false, row) - checkEvaluation(c2 contains "b", null, row) - checkEvaluation(c1 contains Literal.create(null, StringType), null, row) - - checkEvaluation(c1 startsWith "a", true, row) - checkEvaluation(c1 startsWith "b", false, row) - checkEvaluation(c2 startsWith "a", null, row) - checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) - - checkEvaluation(c1 endsWith "c", true, row) - checkEvaluation(c1 endsWith "b", false, row) - checkEvaluation(c2 endsWith "b", null, row) - checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) - } - - test("Substring") { - val row = create_row("example", "example".toArray.map(_.toByte)) - - val s = 'a.string.at(0) - - // substring from zero position with less-than-full length - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) - - // substring from zero position with full length - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) - - // substring from zero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), - "example", row) - checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), - "example", row) - - // substring from nonzero position with less-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), - "xa", row) - - // substring from nonzero position with full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), - "xample", row) - - // substring from nonzero position with greater-than-full length - checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), - "xample", row) - - // zero-length substring (within string bounds) - checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), - "", row) - - // zero-length substring (beyond string bounds) - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - "", row) - - // substring(null, _, _) -> null - checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - null, create_row(null)) - - // substring(_, null, _) -> null - checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), - null, row) - - // substring(_, _, null) -> null - checkEvaluation( - Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), - null, - row) - - // 2-arg substring from zero position - checkEvaluation( - Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "example", - row) - checkEvaluation( - Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "example", - row) - - // 2-arg substring from nonzero position - checkEvaluation( - Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), - "xample", - row) - - val s_notNull = 'a.string.notNull.at(0) - - assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable - === true) - assert( - Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable - === false) - assert(Substring(s_notNull, - Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) - assert(Substring(s_notNull, - Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) - - checkEvaluation(s.substr(0, 2), "ex", row) - checkEvaluation(s.substr(0), "example", row) - checkEvaluation(s.substring(0, 2), "ex", row) - checkEvaluation(s.substring(0), "example", row) - } - - test("SQRT") { - val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) - val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => create_row(l.toDouble)) - val d = 'a.double.at(0) - - for ((row, expected) <- rowSequence zip expectedResults) { - checkEvaluation(Sqrt(d), expected, row) - } - - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(-1), null, EmptyRow) - checkEvaluation(Sqrt(-1.5), null, EmptyRow) - } - - test("Bitwise operations") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(BitwiseAnd(c1, c4), null, row) - checkEvaluation(BitwiseAnd(c1, c2), 0, row) - checkEvaluation(BitwiseAnd(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseAnd(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseOr(c1, c4), null, row) - checkEvaluation(BitwiseOr(c1, c2), 3, row) - checkEvaluation(BitwiseOr(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseOr(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseXor(c1, c4), null, row) - checkEvaluation(BitwiseXor(c1, c2), 3, row) - checkEvaluation(BitwiseXor(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation( - BitwiseXor(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(BitwiseNot(c4), null, row) - checkEvaluation(BitwiseNot(c1), -2, row) - checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 & c2, 0, row) - checkEvaluation(c1 | c2, 3, row) - checkEvaluation(c1 ^ c2, 3, row) - checkEvaluation(~c1, -2, row) - } - - /** - * Used for testing math functions for DataFrames. - * @param c The DataFrame function - * @param f The functions in scala.math - * @param domain The set of values to run the function with - * @param expectNull Whether the given values should return null or not - * @tparam T Generic type for primitives - */ - def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( - c: Expression => Expression, - f: T => T, - domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false): Unit = { - if (expectNull) { - domain.foreach { value => - checkEvaluation(c(Literal(value)), null, EmptyRow) - } - } else { - domain.foreach { value => - checkEvaluation(c(Literal(value)), f(value), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) - } - - test("sin") { - unaryMathFunctionEvaluation(Sin, math.sin) - } - - test("asin") { - unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) - } - - test("sinh") { - unaryMathFunctionEvaluation(Sinh, math.sinh) - } - - test("cos") { - unaryMathFunctionEvaluation(Cos, math.cos) - } - - test("acos") { - unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) - } - - test("cosh") { - unaryMathFunctionEvaluation(Cosh, math.cosh) - } - - test("tan") { - unaryMathFunctionEvaluation(Tan, math.tan) - } - - test("atan") { - unaryMathFunctionEvaluation(Atan, math.atan) - } - - test("tanh") { - unaryMathFunctionEvaluation(Tanh, math.tanh) - } - - test("toDegrees") { - unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) - } - - test("toRadians") { - unaryMathFunctionEvaluation(ToRadians, math.toRadians) - } - - test("cbrt") { - unaryMathFunctionEvaluation(Cbrt, math.cbrt) - } - - test("ceil") { - unaryMathFunctionEvaluation(Ceil, math.ceil) - } - - test("floor") { - unaryMathFunctionEvaluation(Floor, math.floor) - } - - test("rint") { - unaryMathFunctionEvaluation(Rint, math.rint) - } - - test("exp") { - unaryMathFunctionEvaluation(Exp, math.exp) - } - - test("expm1") { - unaryMathFunctionEvaluation(Expm1, math.expm1) - } - - test("signum") { - unaryMathFunctionEvaluation[Double](Signum, math.signum) - } - - test("log") { - unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) - } - - test("log10") { - unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) - } - - test("log1p") { - unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) - unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) - } - - /** - * Used for testing math functions for DataFrames. - * @param c The DataFrame function - * @param f The functions in scala.math - * @param domain The set of values to run the function with - */ - def binaryMathFunctionEvaluation( - c: (Expression, Expression) => Expression, - f: (Double, Double) => Double, - domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false): Unit = { - if (expectNull) { - domain.foreach { case (v1, v2) => - checkEvaluation(c(v1, v2), null, create_row(null)) - } - } else { - domain.foreach { case (v1, v2) => - checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) - checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) - checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) - } - - test("pow") { - binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) - } - - test("hypot") { - binaryMathFunctionEvaluation(Hypot, math.hypot) - } - - test("atan2") { - binaryMathFunctionEvaluation(Atan2, math.atan2) - } -} - -// TODO: Make the tests work with codegen. -class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { - - test("CreateStruct") { - val row = Row(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala deleted file mode 100644 index 97af2e0fd050..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen._ - -/** - * Overrides our expression evaluation tests to use generated code on mutable rows. - */ -class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - lazy val evaluated = GenerateProjection.expressionEvaluator(expression) - - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code.mkString("\n")} - |$e - """.stripMargin) - } - - val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code.mkString("\n")} - """.stripMargin) - } - if (actual != expectedRow) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala new file mode 100644 index 000000000000..d924ff7a102f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + + +class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("null") { + checkEvaluation(Literal.create(null, BooleanType), null) + checkEvaluation(Literal.create(null, ByteType), null) + checkEvaluation(Literal.create(null, ShortType), null) + checkEvaluation(Literal.create(null, IntegerType), null) + checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, FloatType), null) + checkEvaluation(Literal.create(null, LongType), null) + checkEvaluation(Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, BinaryType), null) + checkEvaluation(Literal.create(null, DecimalType()), null) + checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) + checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) + checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) + } + + test("boolean literals") { + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(false), false) + } + + test("int literals") { + List(0, 1, Int.MinValue, Int.MaxValue).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toLong), d.toLong) + checkEvaluation(Literal(d.toShort), d.toShort) + checkEvaluation(Literal(d.toByte), d.toByte) + } + checkEvaluation(Literal(Long.MinValue), Long.MinValue) + checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) + } + + test("double literals") { + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) + } + checkEvaluation(Literal(Double.MinValue), Double.MinValue) + checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) + checkEvaluation(Literal(Float.MinValue), Float.MinValue) + checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) + + } + + test("string literals") { + checkEvaluation(Literal(""), "") + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal("\0"), "\0") + } + + test("sum two literals") { + checkEvaluation(Add(Literal(1), Literal(1)), 2) + } + + test("binary literals") { + checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) + checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) + } + + test("decimal") { + List(0.0, 1.2, 1.1111, 5).foreach { d => + checkEvaluation(Literal(Decimal(d)), Decimal(d)) + checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) + checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) + checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), + Decimal((d * 1000L).toLong, 10, 1)) + } + } + + // TODO(davies): add tests for ArrayType, MapType and StructType +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala new file mode 100644 index 000000000000..7ca9e30b2bcd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import com.google.common.math.LongMath + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ + +class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + /** + * Used for testing leaf math expressions. + * + * @param e expression + * @param c The constants in scala.math + * @tparam T Generic type for primitives + */ + private def testLeaf[T]( + e: () => Expression, + c: T): Unit = { + checkEvaluation(e(), c, EmptyRow) + checkEvaluation(e(), c, create_row(null)) + } + + /** + * Used for testing unary math expressions. + * + * @param c expression + * @param f The functions in scala.math or elsewhere used to generate expected results + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + * @tparam U Generic type for the output of the given function `f` + */ + private def testUnary[T, U]( + c: Expression => Expression, + f: T => U, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false, + evalType: DataType = DoubleType): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null)) + } + + /** + * Used for testing binary math expressions. + * + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + private def testBinary( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null)) + checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("e") { + testLeaf(EulerNumber, math.E) + } + + test("pi") { + testLeaf(Pi, math.Pi) + } + + test("sin") { + testUnary(Sin, math.sin) + } + + test("asin") { + testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNull = true) + } + + test("sinh") { + testUnary(Sinh, math.sinh) + } + + test("cos") { + testUnary(Cos, math.cos) + } + + test("acos") { + testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNull = true) + } + + test("cosh") { + testUnary(Cosh, math.cosh) + } + + test("tan") { + testUnary(Tan, math.tan) + } + + test("atan") { + testUnary(Atan, math.atan) + } + + test("tanh") { + testUnary(Tanh, math.tanh) + } + + test("toDegrees") { + testUnary(ToDegrees, math.toDegrees) + } + + test("toRadians") { + testUnary(ToRadians, math.toRadians) + } + + test("cbrt") { + testUnary(Cbrt, math.cbrt) + } + + test("ceil") { + testUnary(Ceil, math.ceil) + } + + test("floor") { + testUnary(Floor, math.floor) + } + + test("factorial") { + (0 to 20).foreach { value => + checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow) + } + checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) + checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) + checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + } + + test("rint") { + testUnary(Rint, math.rint) + } + + test("exp") { + testUnary(Exp, math.exp) + } + + test("expm1") { + testUnary(Expm1, math.expm1) + } + + test("signum") { + testUnary[Double, Double](Signum, math.signum) + } + + test("log") { + testUnary(Log, math.log, (0 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to -1).map(_ * 0.1), expectNull = true) + } + + test("log10") { + testUnary(Log10, math.log10, (0 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to -1).map(_ * 0.1), expectNull = true) + } + + test("log1p") { + testUnary(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) + } + + test("bin") { + testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) + + val row = create_row(null, 12L, 123L, 1234L, -123L) + val l1 = 'a.long.at(0) + val l2 = 'a.long.at(1) + val l3 = 'a.long.at(2) + val l4 = 'a.long.at(3) + val l5 = 'a.long.at(4) + + checkEvaluation(Bin(l1), null, row) + checkEvaluation(Bin(l2), java.lang.Long.toBinaryString(12), row) + checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row) + checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row) + checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row) + } + + test("log2") { + def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) + testUnary(Log2, f, (0 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + } + + test("sqrt") { + testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true) + + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow) + checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow) + } + + test("pow") { + testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) + } + + test("shift left") { + checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42) + + checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) + checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) + } + + test("shift right") { + checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21) + + checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) + } + + test("shift right unsigned") { + checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21) + + checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) + } + + test("hex") { + checkEvaluation(Hex(Literal.create(null, LongType)), null) + checkEvaluation(Hex(Literal(28L)), "1C") + checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") + checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal.create(null, BinaryType)), null) + checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") + // scalastyle:on + } + + test("unhex") { + checkEvaluation(Unhex(Literal.create(null, StringType)), null) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + checkEvaluation(Unhex(Literal("GG")), null) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) + checkEvaluation(Unhex(Literal("三重的")), null) + + // scalastyle:on + } + + test("hypot") { + testBinary(Hypot, math.hypot) + } + + test("atan2") { + testBinary(Atan2, math.atan2) + } + + test("binary log") { + val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1) + val domain = (1 to 20).map(v => (v * 0.1, v * 0.2)) + + domain.foreach { case (v1, v2) => + checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) + } + + // null input should yield null output + checkEvaluation( + Logarithm(Literal.create(null, DoubleType), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal.create(null, DoubleType)), + null, + create_row(null)) + + // negative input should yield null output + checkEvaluation( + Logarithm(Literal(-1.0), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal(-1.0)), + null, + create_row(null)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala new file mode 100644 index 000000000000..b524d0af14a6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} + +class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("md5") { + checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932") + checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + "6ac1e56bc78f031059be7be854522c4c") + checkEvaluation(Md5(Literal.create(null, BinaryType)), null) + } + + test("sha1") { + checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + "5d211bad8f4ee70e16c7d343a838fc344a1ed961") + checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) + checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + } + + test("sha2") { + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) + checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), + DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) + // unsupported bit length + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) + } + + test("crc32") { + checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L) + checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + 2180413220L) + checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala new file mode 100644 index 000000000000..ccdada8b56f8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{BooleanType, StringType, ShortType} + +class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("null checking") { + val row = create_row("^Ba*n", null, true, null) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.boolean.at(3) + + checkEvaluation(c1.isNull, false, row) + checkEvaluation(c1.isNotNull, true, row) + + checkEvaluation(c2.isNull, true, row) + checkEvaluation(c2.isNotNull, false, row) + + checkEvaluation(Literal.create(1, ShortType).isNull, false) + checkEvaluation(Literal.create(1, ShortType).isNotNull, true) + + checkEvaluation(Literal.create(null, ShortType).isNull, true) + checkEvaluation(Literal.create(null, ShortType).isNotNull, false) + + checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row) + checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) + + checkEvaluation( + If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row) + checkEvaluation(If(c3, c1, c2), "^Ba*n", row) + checkEvaluation(If(c4, c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row) + checkEvaluation(If(Literal.create(false, BooleanType), + Literal.create("a", StringType), Literal.create("b", StringType)), "b", row) + + checkEvaluation(c1 in (c1, c2), true, row) + checkEvaluation( + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row) + checkEvaluation( + Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala new file mode 100644 index 000000000000..188ecef9e767 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.immutable.HashSet + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType} + + +class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { + + private def booleanLogicTest( + name: String, + op: (Expression, Expression) => Expression, + truthTable: Seq[(Any, Any, Any)]) { + test(s"3VL $name") { + truthTable.foreach { + case (l, r, answer) => + val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) + checkEvaluation(expr, answer) + } + } + } + + // scalastyle:off + /** + * Checks for three-valued-logic. Based on: + * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 + * I.e. in flat cpo "False -> Unknown -> True", + * OR is lowest upper bound, + * AND is greatest lower bound. + * p q p OR q p AND q p = q + * True True True True True + * True False True False False + * True Unknown True Unknown Unknown + * False True True False False + * False False False False True + * False Unknown Unknown False Unknown + * Unknown True True Unknown Unknown + * Unknown False Unknown False Unknown + * Unknown Unknown Unknown Unknown Unknown + * + * p NOT p + * True False + * False True + * Unknown Unknown + */ + // scalastyle:on + + test("3VL Not") { + val notTrueTable = + (true, false) :: + (false, true) :: + (null, null) :: Nil + notTrueTable.foreach { case (v, answer) => + checkEvaluation(Not(Literal.create(v, BooleanType)), answer) + } + } + + booleanLogicTest("AND", And, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, false) :: + (false, null, false) :: + (null, true, null) :: + (null, false, false) :: + (null, null, null) :: Nil) + + booleanLogicTest("OR", Or, + (true, true, true) :: + (true, false, true) :: + (true, null, true) :: + (false, true, true) :: + (false, false, false) :: + (false, null, null) :: + (null, true, true) :: + (null, false, null) :: + (null, null, null) :: Nil) + + booleanLogicTest("=", EqualTo, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, true) :: + (false, null, null) :: + (null, true, null) :: + (null, false, null) :: + (null, null, null) :: Nil) + + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation( + And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), + true) + } + + test("INSET") { + val hS = HashSet[Any]() + 1 + 2 + val nS = HashSet[Any]() + 1 + 2 + null + val one = Literal(1) + val two = Literal(2) + val three = Literal(3) + val nl = Literal(null) + checkEvaluation(InSet(one, hS), true) + checkEvaluation(InSet(two, hS), true) + checkEvaluation(InSet(two, nS), true) + checkEvaluation(InSet(nl, nS), true) + checkEvaluation(InSet(three, hS), false) + checkEvaluation(InSet(three, nS), false) + checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + } + + private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) + private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_)) + + private val equalValues1 = smallValues + private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) + + test("BinaryComparison: <") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) < largeValues(i), true) + checkEvaluation(equalValues1(i) < equalValues2(i), false) + checkEvaluation(largeValues(i) < smallValues(i), false) + } + } + + test("BinaryComparison: <=") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) <= largeValues(i), true) + checkEvaluation(equalValues1(i) <= equalValues2(i), true) + checkEvaluation(largeValues(i) <= smallValues(i), false) + } + } + + test("BinaryComparison: >") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) > largeValues(i), false) + checkEvaluation(equalValues1(i) > equalValues2(i), false) + checkEvaluation(largeValues(i) > smallValues(i), true) + } + } + + test("BinaryComparison: >=") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) >= largeValues(i), false) + checkEvaluation(equalValues1(i) >= equalValues2(i), true) + checkEvaluation(largeValues(i) >= smallValues(i), true) + } + } + + test("BinaryComparison: ===") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) === largeValues(i), false) + checkEvaluation(equalValues1(i) === equalValues2(i), true) + checkEvaluation(largeValues(i) === smallValues(i), false) + } + } + + test("BinaryComparison: <=>") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) <=> largeValues(i), false) + checkEvaluation(equalValues1(i) <=> equalValues2(i), true) + checkEvaluation(largeValues(i) <=> smallValues(i), false) + } + } + + test("BinaryComparison: null test") { + val normalInt = Literal(1) + val nullInt = Literal.create(null, IntegerType) + + def nullTest(op: (Expression, Expression) => Expression): Unit = { + checkEvaluation(op(normalInt, nullInt), null) + checkEvaluation(op(nullInt, normalInt), null) + checkEvaluation(op(nullInt, nullInt), null) + } + + nullTest(LessThan) + nullTest(LessThanOrEqual) + nullTest(GreaterThan) + nullTest(GreaterThanOrEqual) + nullTest(EqualTo) + + checkEvaluation(normalInt <=> nullInt, false) + checkEvaluation(nullInt <=> normalInt, false) + checkEvaluation(nullInt <=> nullInt, true) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala new file mode 100644 index 000000000000..9be2b23a53f2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{DoubleType, IntegerType} + + +class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("random") { + val row = create_row(1.1, 2.0, 3.1, null) + checkDoubleEvaluation(Rand(30), (0.7363714192755834 +- 0.001), row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala new file mode 100644 index 000000000000..1efbe1a245e8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} + + +class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("StringComparison") { + val row = create_row("abc", null) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + + checkEvaluation(c1 contains "b", true, row) + checkEvaluation(c1 contains "x", false, row) + checkEvaluation(c2 contains "b", null, row) + checkEvaluation(c1 contains Literal.create(null, StringType), null, row) + + checkEvaluation(c1 startsWith "a", true, row) + checkEvaluation(c1 startsWith "b", false, row) + checkEvaluation(c2 startsWith "a", null, row) + checkEvaluation(c1 startsWith Literal.create(null, StringType), null, row) + + checkEvaluation(c1 endsWith "c", true, row) + checkEvaluation(c1 endsWith "b", false, row) + checkEvaluation(c2 endsWith "b", null, row) + checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) + } + + test("Substring") { + val row = create_row("example", "example".toArray.map(_.toByte)) + + val s = 'a.string.at(0) + + // substring from zero position with less-than-full length + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)), "ex", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(2, IntegerType)), "ex", row) + + // substring from zero position with full length + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(7, IntegerType)), "example", row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(7, IntegerType)), "example", row) + + // substring from zero position with greater-than-full length + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(100, IntegerType)), + "example", row) + checkEvaluation(Substring(s, Literal.create(1, IntegerType), Literal.create(100, IntegerType)), + "example", row) + + // substring from nonzero position with less-than-full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(2, IntegerType)), + "xa", row) + + // substring from nonzero position with full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(6, IntegerType)), + "xample", row) + + // substring from nonzero position with greater-than-full length + checkEvaluation(Substring(s, Literal.create(2, IntegerType), Literal.create(100, IntegerType)), + "xample", row) + + // zero-length substring (within string bounds) + checkEvaluation(Substring(s, Literal.create(0, IntegerType), Literal.create(0, IntegerType)), + "", row) + + // zero-length substring (beyond string bounds) + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + "", row) + + // substring(null, _, _) -> null + checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), + null, create_row(null)) + + // substring(_, null, _) -> null + checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), + null, row) + + // substring(_, _, null) -> null + checkEvaluation( + Substring(s, Literal.create(100, IntegerType), Literal.create(null, IntegerType)), + null, + row) + + // 2-arg substring from zero position + checkEvaluation( + Substring(s, Literal.create(0, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + checkEvaluation( + Substring(s, Literal.create(1, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "example", + row) + + // 2-arg substring from nonzero position + checkEvaluation( + Substring(s, Literal.create(2, IntegerType), Literal.create(Integer.MAX_VALUE, IntegerType)), + "xample", + row) + + val s_notNull = 'a.string.notNull.at(0) + + assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === true) + assert( + Substring(s_notNull, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable + === false) + assert(Substring(s_notNull, + Literal.create(null, IntegerType), Literal.create(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, + Literal.create(0, IntegerType), Literal.create(null, IntegerType)).nullable === true) + + checkEvaluation(s.substr(0, 2), "ex", row) + checkEvaluation(s.substr(0), "example", row) + checkEvaluation(s.substring(0, 2), "ex", row) + checkEvaluation(s.substring(0), "example", row) + } + + test("LIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) + checkEvaluation("a_%b" like "a\\__b", true) + checkEvaluation("addb" like "a_%b", true) + checkEvaluation("addb" like "a\\__b", false) + checkEvaluation("addb" like "a%\\%b", false) + checkEvaluation("a_%b" like "a%\\%b", true) + checkEvaluation("addb" like "a%", true) + checkEvaluation("addb" like "**", false) + checkEvaluation("abc" like "a%", true) + checkEvaluation("abc" like "b%", false) + checkEvaluation("abc" like "bc%", false) + checkEvaluation("a\nb" like "a_b", true) + checkEvaluation("ab" like "a%b", true) + checkEvaluation("a\nb" like "a%b", true) + } + + test("LIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) + } + + test("RLIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike "abdef", true) + checkEvaluation("abbbbc" rlike "a.*c", true) + + checkEvaluation("fofo" rlike "^fo", true) + checkEvaluation("fo\no" rlike "^fo\no$", true) + checkEvaluation("Bn" rlike "^Ba*n", true) + checkEvaluation("afofo" rlike "fo", true) + checkEvaluation("afofo" rlike "^fo", false) + checkEvaluation("Baan" rlike "^Ba?n", false) + checkEvaluation("axe" rlike "pi|apa", false) + checkEvaluation("pip" rlike "^(pi)*$", false) + + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike "**") + } + } + + test("RLIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike regEx, create_row("**")) + } + } + + test("length for string") { + val a = 'a.string.at(0) + checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) + checkEvaluation(StringLength(a), 5, create_row("abdef")) + checkEvaluation(StringLength(a), 0, create_row("")) + checkEvaluation(StringLength(a), null, create_row(null)) + checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) + } + + test("ascii for string") { + val a = 'a.string.at(0) + checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) + checkEvaluation(Ascii(a), 97, create_row("abdef")) + checkEvaluation(Ascii(a), 0, create_row("")) + checkEvaluation(Ascii(a), null, create_row(null)) + checkEvaluation(Ascii(Literal.create(null, StringType)), null, create_row("abdef")) + } + + test("base64/unbase64 for string") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 4) + + checkEvaluation(Base64(Literal(bytes)), "AQIDBA==", create_row("abdef")) + checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef")) + checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef")) + checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, create_row("abdef")) + checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA==")) + + checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes)) + checkEvaluation(Base64(b), "", create_row(Array[Byte]())) + checkEvaluation(Base64(b), null, create_row(null)) + checkEvaluation(Base64(Literal.create(null, StringType)), null, create_row("abdef")) + + checkEvaluation(UnBase64(a), null, create_row(null)) + checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef")) + } + + test("encode/decode for string") { + val a = 'a.string.at(0) + val b = 'b.binary.at(0) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + checkEvaluation( + Decode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界") + checkEvaluation( + Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界")) + checkEvaluation( + Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row("")) + // scalastyle:on + checkEvaluation(Encode(a, Literal("utf-8")), null, create_row(null)) + checkEvaluation(Encode(Literal.create(null, StringType), Literal("utf-8")), null) + checkEvaluation(Encode(a, Literal.create(null, StringType)), null, create_row("")) + + checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null)) + checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null) + checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null)) + } + + test("Levenshtein distance") { + checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null) + checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null) + checkEvaluation(Levenshtein(Literal(""), Literal("")), 0) + checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0) + checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3) + checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7a19e511eb8b..6fafc2f86684 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,18 +20,25 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.JavaConverters._ import scala.util.Random -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.types.UTF8String -class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { - import UnsafeFixedWidthAggregationMap._ +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach { private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) - private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0)) + private def emptyProjection: Projection = + GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) + private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: TaskMemoryManager = null @@ -46,21 +53,11 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with Be } } - test("supported schemas") { - assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - - assert( - !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - } - test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -71,14 +68,14 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with Be test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics ) - val groupKey = new GenericRow(Array[Any](UTF8String("cats"))) + val groupKey = InternalRow(UTF8String.fromString("cats")) // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) map.getAggregationBuffer(groupKey) @@ -97,9 +94,9 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with Be test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 128, // initial capacity false // disable perf metrics @@ -107,13 +104,43 @@ class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with Be val rand = new Random(42) val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet groupKeys.foreach { keyString => - map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String(keyString)))) + map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) } val seenKeys: Set[String] = map.iterator().asScala.map { entry => entry.key.getString(0) }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) + + map.free() + } + + test("with decimal in the key and values") { + val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) + val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) + val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), + Seq(AttributeReference("price", DecimalType.Unlimited)())) + val map = new UnsafeFixedWidthAggregationMap( + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), + memoryManager, + 1, // initial capacity + false // disable perf metrics + ) + + (0 until 100).foreach { i => + val groupKey = InternalRow(Decimal(i % 10)) + val row = map.getAggregationBuffer(groupKey) + row.update(0, Decimal(i)) + } + val seenKeys: Set[Int] = map.iterator().asScala.map { entry => + entry.key.getAs[Decimal](0).toInt + }.toSet + seenKeys.size should be (10) + seenKeys should be ((0 until 10).toSet) + + map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 3a60c7fd3267..94c2f3242b12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} import java.util.Arrays -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends FunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) @@ -37,40 +41,124 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers { row.setInt(2, 2) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (3 * 8)) + assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getLong(1) should be (1) - unsafeRow.getInt(2) should be (2) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getLong(1) === 1) + assert(unsafeRow.getInt(2) === 2) + + unsafeRow.setLong(1, 3) + assert(unsafeRow.getLong(1) === 3) + unsafeRow.setInt(2, 4) + assert(unsafeRow.getInt(2) === 4) + } + + test("basic conversion with primitive, string and binary types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.update(1, UTF8String.fromString("Hello")) + row.update(2, "World".getBytes) + + val sizeRequired: Int = converter.getSizeRequirement(row) + assert(sizeRequired === 8 + (8 * 3) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) + + val unsafeRow = new UnsafeRow() + val pool = new ObjectPool(10) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") + assert(unsafeRow.get(2) === "World".getBytes) + + unsafeRow.update(1, UTF8String.fromString("World")) + assert(unsafeRow.getString(1) === "World") + assert(pool.size === 0) + unsafeRow.update(1, UTF8String.fromString("Hello World")) + assert(unsafeRow.getString(1) === "Hello World") + assert(pool.size === 1) + + unsafeRow.update(2, "World".getBytes) + assert(unsafeRow.get(2) === "World".getBytes) + assert(pool.size === 1) + unsafeRow.update(2, "Hello World".getBytes) + assert(unsafeRow.get(2) === "Hello World".getBytes) + assert(pool.size === 2) + } + + test("basic conversion with primitive, decimal and array") { + val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.update(1, Decimal(1)) + row.update(2, Array(2)) + + val pool = new ObjectPool(10) + val sizeRequired: Int = converter.getSizeRequirement(row) + assert(sizeRequired === 8 + (8 * 3)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + assert(numBytesWritten === sizeRequired) + assert(pool.size === 2) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.get(1) === Decimal(1)) + assert(unsafeRow.get(2) === Array(2)) + + unsafeRow.update(1, Decimal(2)) + assert(unsafeRow.get(1) === Decimal(2)) + unsafeRow.update(2, Array(3, 4)) + assert(unsafeRow.get(2) === Array(3, 4)) + assert(pool.size === 2) } - test("basic conversion with primitive and string types") { - val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + test("basic conversion with primitive, string, date and timestamp types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = new UnsafeRowConverter(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.setString(1, "Hello") - row.setString(2, "World") + row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))) + row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 3) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + assert(sizeRequired === 8 + (8 * 4) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") - unsafeRow.getString(2) should be ("World") + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") + // Date is represented as Int in unsafeRow + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) + // Timestamp is represented as Long in unsafeRow + DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-05-08 08:10:25")) + + unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) + unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) + DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -82,10 +170,15 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers { IntegerType, LongType, FloatType, - DoubleType) + DoubleType, + StringType, + BinaryType, + DecimalType.Unlimited, + ArrayType(IntegerType) + ) val converter = new UnsafeRowConverter(fieldTypes) - val rowWithAllNullColumns: Row = { + val rowWithAllNullColumns: InternalRow = { val r = new SpecificMutableRow(fieldTypes) for (i <- 0 to fieldTypes.length - 1) { r.setNullAt(i) @@ -96,8 +189,8 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( @@ -105,18 +198,22 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers { for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } - createdFromNull.getBoolean(1) should be (false) - createdFromNull.getByte(2) should be (0) - createdFromNull.getShort(3) should be (0) - createdFromNull.getInt(4) should be (0) - createdFromNull.getLong(5) should be (0) + assert(createdFromNull.getBoolean(1) === false) + assert(createdFromNull.getByte(2) === 0) + assert(createdFromNull.getShort(3) === 0) + assert(createdFromNull.getInt(4) === 0) + assert(createdFromNull.getLong(5) === 0) assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getString(8) === null) + assert(createdFromNull.get(9) === null) + assert(createdFromNull.get(10) === null) + assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by // creating an entirely null row via the converter - val rowWithNoNullColumns: Row = { + val rowWithNoNullColumns: InternalRow = { val r = new SpecificMutableRow(fieldTypes) r.setNullAt(0) r.setBoolean(1, false) @@ -126,28 +223,68 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers { r.setLong(5, 500) r.setFloat(6, 600) r.setDouble(7, 700) + r.update(8, UTF8String.fromString("hello")) + r.update(9, "world".getBytes) + r.update(10, Decimal(10)) + r.update(11, Array(11)) r } - val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val pool = new ObjectPool(1) + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) - setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) - setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) - setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) - setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) - setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) - setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) - setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) - setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- 0 to fieldTypes.length - 1) { + if (i >= 8) { + setToNullAfterCreation.update(i, null) + } setToNullAfterCreation.setNullAt(i) } - assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + // There are some garbage left in the var-length area + assert(Arrays.equals(createdFromNullBuffer, + java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8))) + + setToNullAfterCreation.setNullAt(0) + setToNullAfterCreation.setBoolean(1, false) + setToNullAfterCreation.setByte(2, 20) + setToNullAfterCreation.setShort(3, 30) + setToNullAfterCreation.setInt(4, 400) + setToNullAfterCreation.setLong(5, 500) + setToNullAfterCreation.setFloat(6, 600) + setToNullAfterCreation.setDouble(7, 700) + setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + setToNullAfterCreation.update(9, "world".getBytes) + setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.update(11, Array(11)) + + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 6255578d7fa5..465a5e691420 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -78,9 +78,9 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { test("(a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ...") { checkCondition('b > 3 || 'c > 5, 'b > 3 || 'c > 5) - checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) + checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) - checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) + checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) val input = ('a === 'b && 'b > 3 && 'c > 2) || ('a === 'b && 'c < 1 && 'a === 5) || diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index a30052b38fc1..06c592f4905a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -71,7 +71,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("limits: combines two limits after ColumnPruning") { val originalQuery = testRelation @@ -79,7 +79,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(2) .select('a) .limit(5) - + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 5697c2272b8e..ec3b2f1edfa0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -248,7 +248,7 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("Constant folding test: Fold In(v, list) into true or false") { var originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 6841bd9890c9..54e8c6462e96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -37,13 +37,11 @@ class ConvertToLocalRelationSuite extends PlanTest { test("Project on LocalRelation should be turned into a single LocalRelation") { val testRelation = LocalRelation( LocalRelation('a.int, 'b.int).output, - Row(1, 2) :: - Row(4, 5) :: Nil) + InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) val correctAnswer = LocalRelation( LocalRelation('a1.int, 'b1.int).output, - Row(1, 3) :: - Row(4, 6) :: Nil) + InternalRow(1, 3) :: InternalRow(4, 6) :: Nil) val projectOnLocal = testRelation.select( UnresolvedAttribute("a").as("a1"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index be33cb9bb8ea..ffdc673cdc45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -36,6 +36,7 @@ class FilterPushdownSuite extends PlanTest { Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, + BooleanSimplification, PushPredicateThroughJoin, PushPredicateThroughGenerate, ColumnPruning, @@ -93,11 +94,11 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("column pruning for Project(ne, Limit)") { val originalQuery = testRelation - .select('a,'b) + .select('a, 'b) .limit(2) .select('a) @@ -109,7 +110,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -156,11 +157,9 @@ class FilterPushdownSuite extends PlanTest { .where('a === 1 && 'a === 2) .select('a).analyze - comparePlans(optimized, correctAnswer) } - test("joins: push to either side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -198,6 +197,25 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: push to one side after transformCondition") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.a".attr === 1 && "y.d".attr === "x.b".attr) || + ("x.a".attr === 1 && "y.d".attr === "x.c".attr)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a === 1) + val right = testRelation1 + val correctAnswer = + left.join(right, condition = Some("d".attr === "b".attr || "d".attr === "c".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + test("joins: rewrite filter to push to either side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -563,17 +581,16 @@ class FilterPushdownSuite extends PlanTest { // push down invalid val originalQuery1 = { x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) + .sortBy(SortOrder('a, Ascending)) + .select('b) } val optimized1 = Optimize.execute(originalQuery1.analyze) val correctAnswer1 = x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1)) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 3eb399e68e70..1d433275fed2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -46,7 +46,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -57,17 +57,17 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala new file mode 100644 index 000000000000..151654bffbd6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ProjectCollapsingSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two deterministic, independent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two deterministic, dependent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select(('a_plus_1 + 1).as('a_plus_2), 'b) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation.select( + (('a + 1).as('a_plus_1) + 1).as('a_plus_2), + 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse nondeterministic projects") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala new file mode 100644 index 000000000000..df29a62ff0e1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceDistinctWithAggregateSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil + } + + test("replace distinct with aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index a3ad200800b0..ec379489a6d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class UnionPushdownSuite extends PlanTest { +class UnionPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - UnionPushdown) :: Nil + UnionPushDown) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testUnion = Union(testRelation, testRelation2) test("union: filter to each side") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e7cafcc96de8..765c1e2dda99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -26,7 +25,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Provides helper methods for comparing plans. */ -class PlanTest extends FunSuite { +class PlanTest extends SparkFunSuite { /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 1273921f6394..62d5f6ac7488 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} @@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Tests for the sameResult function of [[LogicalPlan]]. */ -class SameResultSuite extends FunSuite { +class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 2a641c63f87b..a7de7b052bdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.trees -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -class RuleExecutorSuite extends FunSuite { +class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 3d10dab5ba34..86792f021757 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -19,21 +19,24 @@ package org.apache.spark.sql.catalyst.trees import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegerType, StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children: Seq[Expression] = optKey.toSeq - def nullable: Boolean = true - def dataType: NullType = NullType + override def children: Seq[Expression] = optKey.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType override lazy val resolved = true - type EvaluatedType = Any - def eval(input: Row): Any = null.asInstanceOf[Any] + override def eval(input: InternalRow): Any = null.asInstanceOf[Any] +} + +case class ComplexPlan(exprs: Seq[Seq[Expression]]) + extends org.apache.spark.sql.catalyst.plans.logical.LeafNode { + override def output: Seq[Attribute] = Nil } -class TreeNodeSuite extends FunSuite { +class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } assert(after === Literal(2)) @@ -70,7 +73,7 @@ class TreeNodeSuite extends FunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -82,7 +85,7 @@ class TreeNodeSuite extends FunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -92,7 +95,7 @@ class TreeNodeSuite extends FunSuite { test("transform works on nodes with Option children") { val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) - val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } var actual = dummy1 transformDown toZero assert(actual === Dummy(Some(Literal(0)))) @@ -105,7 +108,7 @@ class TreeNodeSuite extends FunSuite { } test("preserves origin") { - CurrentOrigin.setPosition(1,1) + CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) CurrentOrigin.reset() @@ -122,7 +125,7 @@ class TreeNodeSuite extends FunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryExpression => actual.append(b.symbol); + case b: BinaryOperator => actual.append(b.symbol); case l: Literal => actual.append(l.toString); } @@ -222,4 +225,13 @@ class TreeNodeSuite extends FunSuite { assert(expected === actual) } } + + test("transformExpressions on nested expression sequence") { + val plan = ComplexPlan(Seq(Seq(Literal(1)), Seq(Literal(2)))) + val actual = plan.transformExpressions { + case Literal(value, _) => Literal(value.toString) + } + val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) + assert(expected === actual) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala new file mode 100644 index 000000000000..1d4a60c81efc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat + +import org.apache.spark.SparkFunSuite + +class DateTimeUtilsSuite extends SparkFunSuite { + + test("timestamp and 100ns") { + val now = new Timestamp(System.currentTimeMillis()) + now.setNanos(100) + val ns = DateTimeUtils.fromJavaTimestamp(now) + assert(ns % 10000000L === 1) + assert(DateTimeUtils.toJavaTimestamp(ns) === now) + + List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => + val ts = DateTimeUtils.toJavaTimestamp(t) + assert(DateTimeUtils.fromJavaTimestamp(ts) === t) + assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts) + } + } + + test("100ns and julian day") { + val (d, ns) = DateTimeUtils.toJulianDay(0) + assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) + assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) + assert(DateTimeUtils.fromJulianDay(d, ns) == 0L) + + val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) + val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t)) + val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) + assert(t.equals(t2)) + } + + test("SPARK-6785: java date conversion before and after epoch") { + def checkFromToJavaDate(d1: Date): Unit = { + val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1)) + assert(d2.toString === d1.toString) + } + + val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z") + + checkFromToJavaDate(new Date(100)) + + checkFromToJavaDate(Date.valueOf("1970-01-01")) + + checkFromToJavaDate(new Date(df1.parse("1970-01-01 00:00:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1970-01-01 00:00:00 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1970-01-01 00:00:01").getTime)) + checkFromToJavaDate(new Date(df2.parse("1970-01-01 00:00:01 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1969-12-31 23:59:59").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-12-31 23:59:59 UTC").getTime)) + + checkFromToJavaDate(Date.valueOf("1969-01-01")) + + checkFromToJavaDate(new Date(df1.parse("1969-01-01 00:00:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-01-01 00:00:00 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1969-01-01 00:00:01").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-01-01 00:00:01 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1989-11-09 11:59:59").getTime)) + checkFromToJavaDate(new Date(df2.parse("1989-11-09 19:59:59 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1776-07-04 10:30:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1776-07-04 18:30:00 UTC").getTime)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index d7d60efee50f..4030a1b1df35 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{MetadataBuilder, Metadata} -class MetadataSuite extends FunSuite { +class MetadataSuite extends SparkFunSuite { val baseMetadata = new MetadataBuilder() .putString("purpose", "ml") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala new file mode 100644 index 000000000000..94764df4b9cd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class ObjectPoolSuite extends SparkFunSuite with Matchers { + + test("pool") { + val pool = new ObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(false) === 2) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.get(2) === false) + assert(pool.size() === 3) + + pool.replace(1, "world") + assert(pool.get(1) === "world") + assert(pool.size() === 3) + } + + test("unique pool") { + val pool = new UniqueObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.size() === 2) + + intercept[UnsupportedOperationException] { + pool.replace(1, "world") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 3e7cf7cbb5e6..c6171b7b6916 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class DataTypeParserSuite extends FunSuite { +class DataTypeParserSuite extends SparkFunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index d797510f3668..14e7b4a9561b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -class DataTypeSuite extends FunSuite { +class DataTypeSuite extends SparkFunSuite { test("construct an ArrayType") { val array = ArrayType(StringType) @@ -33,6 +33,37 @@ class DataTypeSuite extends FunSuite { assert(MapType(StringType, IntegerType, true) === map) } + test("construct with add") { + val struct = (new StructType) + .add("a", IntegerType, true) + .add("b", LongType, false) + .add("c", StringType, true) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with add from StructField") { + // Test creation from StructField type + val struct = (new StructType) + .add(StructField("a", IntegerType, true)) + .add(StructField("b", LongType, false)) + .add(StructField("c", StringType, true)) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with String DataType") { + // Test creation with DataType as String + val struct = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + + assert(StructField("a", IntegerType, true) === struct("a")) + assert(StructField("b", LongType, false) === struct("b")) + assert(StructField("c", StringType, true) === struct("c")) + } + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: @@ -69,6 +100,76 @@ class DataTypeSuite extends FunSuite { } } + test("fieldsMap returns map of name to StructField") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val mapped = StructType.fieldsMap(struct.fields) + + val expected = Map( + "a" -> StructField("a", LongType), + "b" -> StructField("b", FloatType)) + + assert(mapped === expected) + } + + test("merge where right is empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType(List()) + val merged = left.merge(right) + + assert(merged === left) + } + + test("merge where left is empty") { + + val left = StructType(List()) + + val right = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val merged = left.merge(right) + + assert(right === merged) + + } + + test("merge where both are non-empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("c", LongType) :: Nil) + + val expected = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: + StructField("c", LongType) :: Nil) + + val merged = left.merge(right) + + assert(merged === expected) + } + + test("merge where right contains type conflict") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("b", LongType) :: Nil) + + intercept[SparkException] { + left.merge(right) + } + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) @@ -120,7 +221,7 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) checkDefaultSize(DateType, 4) - checkDefaultSize(TimestampType,12) + checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) @@ -179,11 +280,11 @@ class DataTypeSuite extends FunSuite { expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), expected = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala new file mode 100644 index 000000000000..32632b5d6e34 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +/** + * Utility functions for working with DataTypes in tests. + */ +object DataTypeTestUtils { + + /** + * Instances of all [[IntegralType]]s. + */ + val integralType: Set[IntegralType] = Set( + ByteType, ShortType, IntegerType, LongType + ) + + /** + * Instances of all [[FractionalType]]s, including both fixed- and unlimited-precision + * decimal types. + */ + val fractionalTypes: Set[FractionalType] = Set( + DecimalType(precisionInfo = None), + DecimalType(2, 1), + DoubleType, + FloatType + ) + + /** + * Instances of all [[NumericType]]s. + */ + val numericTypes: Set[NumericType] = integralType ++ fractionalTypes + + /** + * Instances of all [[AtomicType]]s. + */ + val atomicTypes: Set[DataType] = numericTypes ++ Set( + BinaryType, + BooleanType, + DateType, + StringType, + TimestampType + ) + + /** + * Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null. + */ + val atomicArrayTypes: Set[ArrayType] = atomicTypes.map(ArrayType(_, containsNull = true)) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala deleted file mode 100644 index a22aa6f244c4..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.types - -import org.scalatest.FunSuite - -// scalastyle:off -class UTF8StringSuite extends FunSuite { - test("basic") { - def check(str: String, len: Int) { - - assert(UTF8String(str).length == len) - assert(UTF8String(str.getBytes("utf8")).length() == len) - - assert(UTF8String(str) == str) - assert(UTF8String(str.getBytes("utf8")) == str) - assert(UTF8String(str).toString == str) - assert(UTF8String(str.getBytes("utf8")).toString == str) - assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) - - assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) - } - - check("hello", 5) - check("世 界", 3) - } - - test("contains") { - assert(UTF8String("hello").contains(UTF8String("ello"))) - assert(!UTF8String("hello").contains(UTF8String("vello"))) - assert(UTF8String("大千世界").contains(UTF8String("千世"))) - assert(!UTF8String("大千世界").contains(UTF8String("世千"))) - } - - test("prefix") { - assert(UTF8String("hello").startsWith(UTF8String("hell"))) - assert(!UTF8String("hello").startsWith(UTF8String("ell"))) - assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) - assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) - } - - test("suffix") { - assert(UTF8String("hello").endsWith(UTF8String("ello"))) - assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) - assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) - assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) - } - - test("slice") { - assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) - assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) - assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) - assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index de6a2cd448c4..5f312964e5bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.types.decimal +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.Decimal -import org.scalatest.{PrivateMethodTester, FunSuite} +import org.scalatest.PrivateMethodTester import scala.language.postfixOps -class DecimalSuite extends FunSuite with PrivateMethodTester { +class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { /** Check that a Decimal has the given string representation, precision and scale */ def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { @@ -155,4 +156,20 @@ class DecimalSuite extends FunSuite with PrivateMethodTester { assert(Decimal(-100) % Decimal(3) === Decimal(-1)) assert(Decimal(100) % Decimal(0) === null) } + + test("set/setOrNull") { + assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L) + assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) + assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) + } + + test("accurate precision after multiplication") { + val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal + assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") + } + + test("fix non-terminating decimal expansion problem") { + val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) + assert(decimal.toString === "0.333") + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ffe95bb49188..8fc16928adbd 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-catalyst_${scala.binary.version} @@ -54,11 +61,11 @@ test - com.twitter + org.apache.parquet parquet-column - com.twitter + org.apache.parquet parquet-hadoop @@ -66,11 +73,6 @@ jackson-databind ${fasterxml.jackson.version} - - org.jodd - jodd-core - ${jodd.version} - junit junit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index dc0aeea7c4ae..f201c8ea8a11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -349,7 +348,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def when(condition: Column, value: Any):Column = this.expr match { + def when(condition: Column, value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) case _ => @@ -378,7 +377,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def otherwise(value: Any):Column = this.expr match { + def otherwise(value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => if (branches.size % 2 == 0) { CaseWhen(branches :+ lit(value).expr) @@ -621,7 +620,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ @scala.annotation.varargs - def in(list: Column*): Column = In(expr, list.map(_.expr)) + def in(list: Any*): Column = In(expr, list.map(lit(_).expr)) /** * SQL like expression. @@ -716,6 +715,18 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def endsWith(literal: String): Column = this.endsWith(lit(literal)) + /** + * Gives the column an alias. Same as `as`. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".alias("colB")) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def alias(alias: String): Column = as(alias) + /** * Gives the column an alias. * {{{ @@ -889,6 +900,22 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) + /** + * Define a windowing column. + * + * {{{ + * val w = Window.partitionBy("name").orderBy("id") + * df.select( + * sum("price").over(w.rangeBetween(Long.MinValue, 2)), + * avg("price").over(w.rowsBetween(0, 4)) + * ) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 2e20c3d3f4ed..caad2da80b1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql import java.io.CharArrayWriter -import java.sql.DriverManager import java.util.Properties -import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -34,15 +32,14 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} -import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JacksonGenerator -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.sources.CreateTableUsingAsSelect import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -59,14 +56,11 @@ private[sql] object DataFrame { * :: Experimental :: * A distributed collection of data organized into named columns. * - * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways - * to create a [[DataFrame]]: + * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates + * a [[DataFrame]] by pointing Spark SQL to a Parquet data set. * {{{ - * // Create a DataFrame from Parquet files - * val people = sqlContext.parquetFile("...") - * - * // Create a DataFrame from data sources - * val df = sqlContext.load("...", "json") + * val people = sqlContext.read.parquet("...") // in Scala + * DataFrame people = sqlContext.read().parquet("...") // in Java * }}} * * Once created, it can be manipulated using the various domain-specific-language (DSL) functions @@ -88,8 +82,8 @@ private[sql] object DataFrame { * A more concrete example in Scala: * {{{ * // To create DataFrame using SQLContext - * val people = sqlContext.parquetFile("...") - * val department = sqlContext.parquetFile("...") + * val people = sqlContext.read.parquet("...") + * val department = sqlContext.read.parquet("...") * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) @@ -100,8 +94,8 @@ private[sql] object DataFrame { * and in Java: * {{{ * // To create DataFrame using SQLContext - * DataFrame people = sqlContext.parquetFile("..."); - * DataFrame department = sqlContext.parquetFile("..."); + * DataFrame people = sqlContext.read().parquet("..."); + * DataFrame department = sqlContext.read().parquet("..."); * * people.filter("age".gt(30)) * .join(department, people.col("deptId").equalTo(department("id"))) @@ -174,23 +168,35 @@ class DataFrame private[sql]( /** * Internal API for Python - * @param numRows Number of rows to show + * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(numRows: Int): String = { + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) val sb = new StringBuilder - val data = take(numRows) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) val numCols = schema.fieldNames.length + // For array values, replace Seq and Array with square brackets // For cells that are beyond 20 characters, replace it with the first 17 and "..." val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => - val str = if (cell == null) "null" else cell.toString - if (str.length > 20) str.substring(0, 17) + "..." else str + val str = cell match { + case null => "null" + case array: Array[_] => array.mkString("[", ", ", "]") + case seq: Seq[_] => seq.mkString("[", ", ", "]") + case _ => cell.toString + } + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str }: Seq[String] } + // Initialise the width of each column to a minimum value of '3' + val colWidths = Array.fill(numCols)(3) + // Compute the width of each column - val colWidths = Array.fill(numCols)(0) for (row <- rows) { for ((cell, i) <- row.zipWithIndex) { colWidths(i) = math.max(colWidths(i), cell.length) @@ -202,7 +208,11 @@ class DataFrame private[sql]( // column names rows.head.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell.toString, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") sb.append(sep) @@ -210,11 +220,22 @@ class DataFrame private[sql]( // data rows.tail.map { _.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell.toString, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") } sb.append(sep) + + // For Data that has more than "numRows" records + if (hasMoreData) { + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows ${rowsString}\n") + } + sb.toString() } @@ -227,10 +248,6 @@ class DataFrame private[sql]( } } - /** Left here for backward compatibility. */ - @deprecated("1.3.0", "use toDF") - def toSchemaRDD: DataFrame = this - /** * Returns the object itself. * @group basic @@ -261,7 +278,7 @@ class DataFrame private[sql]( val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } - select(newCols :_*) + select(newCols : _*) } /** @@ -323,7 +340,8 @@ class DataFrame private[sql]( def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[DataFrame]] in a tabular form. Strings more than 20 characters will be + * truncated, and all cells will be aligned right. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -337,15 +355,46 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def show(numRows: Int): Unit = println(showString(numRows)) + def show(numRows: Int): Unit = show(numRows, true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. * @group action * @since 1.3.0 */ def show(): Unit = show(20) + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + /** * Returns a [[DataFrameNaFunctions]] for working with missing data. * {{{ @@ -404,22 +453,50 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumn: String): DataFrame = { + join(right, Seq(usingColumn)) + } + + /** + * Inner equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the columns "user_id" and "user_name" + * df1.join(df2, Seq("user_id", "user_name")) + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @group dfops + * @since 1.4.0 + */ + def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] - // Project only one of the join column. - val joinedCol = joined.right.resolve(usingColumn) + // Project only one of the join columns. + val joinedCols = usingColumns.map(col => joined.right.resolve(col)) + val condition = usingColumns.map { col => + catalyst.expressions.EqualTo(joined.left.resolve(col), joined.right.resolve(col)) + }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => + catalyst.expressions.And(cond, eqTo) + } + Project( - joined.output.filterNot(_ == joinedCol), + joined.output.filterNot(joinedCols.contains(_)), Join( joined.left, joined.right, joinType = Inner, - Some(expressions.EqualTo( - joined.left.resolve(usingColumn), - joined.right.resolve(usingColumn)))) + condition) ) } @@ -486,8 +563,9 @@ class DataFrame private[sql]( // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. val cond = plan.condition.map { _.transform { - case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => - expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) }} plan.copy(condition = cond) } @@ -505,7 +583,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): DataFrame = { - sort((sortCol +: sortCols).map(apply) :_*) + sort((sortCol +: sortCols).map(apply) : _*) } /** @@ -536,7 +614,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*) + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) /** * Returns a new [[DataFrame]] sorted by the given expressions. @@ -545,7 +623,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*) + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) /** * Selects column based on the column name and return it as a [[Column]]. @@ -582,7 +660,7 @@ class DataFrame private[sql]( def as(alias: Symbol): DataFrame = as(alias.name) /** - * Selects a set of expressions. + * Selects a set of column based expressions. * {{{ * df.select($"colA", $"colB" + 1) * }}} @@ -592,6 +670,10 @@ class DataFrame private[sql]( @scala.annotation.varargs def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) case Column(expr: NamedExpression) => expr // Leave an unaliased explode with an empty list of names since the analzyer will generate the // correct defaults after the nested expression's type has been resolved. @@ -616,7 +698,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -641,7 +723,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 @@ -666,13 +747,24 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 */ def where(condition: Column): DataFrame = filter(condition) + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDf.where("age > 15") + * }}} + * @group dfops + * @since 1.5.0 + */ + def where(conditionExpr: String): DataFrame = { + filter(Column(new SqlParser().parseExpression(conditionExpr))) + } + /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. @@ -691,7 +783,53 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr)) + def groupBy(cols: Column*): GroupedData = { + GroupedData(this, cols.map(_.expr), GroupedData.GroupByType) + } + + /** + * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * df.rollup($"department", $"group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * df.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def rollup(cols: Column*): GroupedData = { + GroupedData(this, cols.map(_.expr), GroupedData.RollupType) + } + + /** + * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * df.cube($"department", $"group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * df.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType) /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -716,7 +854,61 @@ class DataFrame private[sql]( @scala.annotation.varargs def groupBy(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new GroupedData(this, colNames.map(colName => resolve(colName))) + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType) + } + + /** + * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * This is a variant of rollup that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * df.rollup("department", "group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * df.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def rollup(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType) + } + + /** + * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * This is a variant of cube that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * df.cube("department", "group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * df.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def cube(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType) } /** @@ -730,7 +922,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs :_*) + groupBy().agg(aggExpr, aggExprs : _*) } /** @@ -768,7 +960,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function @@ -892,9 +1084,10 @@ class DataFrame private[sql]( val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } val names = schema.toAttributes.map(_.name) + val convert = CatalystTypeConverters.createToCatalystConverter(schema) val rowFunction = - f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) + f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) Generate(generator, join = true, outer = false, @@ -907,7 +1100,7 @@ class DataFrame private[sql]( * columns of the input row are implicitly joined with each value that is output by the function. * * {{{ - * df.explode("words", "word")(words: String => words.split(" ")) + * df.explode("words", "word"){words: String => words.split(" ")} * }}} * @group dfops * @since 1.3.0 @@ -920,8 +1113,9 @@ class DataFrame private[sql]( val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } val names = attributes.map(_.name) - def rowFunction(row: Row): TraversableOnce[Row] = { - f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) + def rowFunction(row: Row): TraversableOnce[InternalRow] = { + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) } val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) @@ -944,7 +1138,7 @@ class DataFrame private[sql]( val name = field.name if (resolver(name, colName)) col.as(colName) else Column(name) } - select(colNames :_*) + select(colNames : _*) } else { select(Column("*"), col.as(colName)) } @@ -990,6 +1184,22 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] with a column dropped. + * This version of drop accepts a Column rather than a name. + * This is a no-op if the DataFrame doesn't have a column + * with an equivalent expression. + * @group dfops + * @since 1.4.1 + */ + def drop(col: Column): DataFrame = { + val attrs = this.logicalPlan.output + val colsAfterDrop = attrs.filter { attr => + attr != col.expr + }.map(attr => Column(attr)) + select(colsAfterDrop : _*) + } + /** * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. * This is an alias for `distinct`. @@ -1067,25 +1277,26 @@ class DataFrame private[sql]( val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList - val ret: Seq[Row] = if (outputCols.nonEmpty) { + val ret: Seq[InternalRow] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq // Pivot the data so each summary is one row row.grouped(outputCols.size).toSeq.zip(statistics).map { - case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) + case (aggregation, (statistic, _)) => + InternalRow(statistic :: aggregation.toList: _*) } } else { // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => Row(name) } + statistics.map { case (name, _) => InternalRow(name) } } - // The first column is string type, and the rest are double type. + // All columns are string type val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes LocalRelation(schema, ret) } @@ -1167,7 +1378,7 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) /** * Returns the number of rows in the [[DataFrame]]. @@ -1203,7 +1414,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - override def distinct: DataFrame = Distinct(logicalPlan) + override def distinct: DataFrame = dropDuplicates() /** * @group basic @@ -1258,7 +1469,7 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().mapPartitions { rows => + queryExecution.toRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } @@ -1289,23 +1500,123 @@ class DataFrame private[sql]( sqlContext.registerDataFrameAsTable(this, tableName) } + /** + * :: Experimental :: + * Interface for saving the content of the [[DataFrame]] out into external storage. + * + * @group output + * @since 1.4.0 + */ + @Experimental + def write: DataFrameWriter = new DataFrameWriter(this) + + /** + * Returns the content of the [[DataFrame]] as a RDD of JSON strings. + * @group rdd + * @since 1.3.0 + */ + def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val writer = new CharArrayWriter() + // create the Generator without separator inserted between 2 records + val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + + new Iterator[String] { + override def hasNext: Boolean = iter.hasNext + override def next(): String = { + JacksonGenerator(rowSchema, gen)(iter.next()) + gen.flush() + + val json = writer.toString + if (hasNext) { + writer.reset() + } else { + gen.close() + } + + json + } + } + } + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + + /** + * Converts a JavaRDD to a PythonRDD. + */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } + + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // Deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + /** + * @deprecated As of 1.3.0, replaced by `toDF()`. + */ + @deprecated("use toDF", "1.3.0") + def toSchemaRDD: DataFrame = this + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. + * If you pass `true` for `allowExisting`, it will drop any table with the + * given name; if you pass `false`, it will throw if the table already + * exists. + * @group output + * @deprecated As of 1.340, replaced by `write().jdbc()`. + */ + @deprecated("Use write.jdbc()", "1.4.0") + def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { + val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write + w.jdbc(url, table, new Properties) + } + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * Assumes the table already exists and has a compatible schema. If you + * pass `true` for `overwrite`, it will `TRUNCATE` the table before + * performing the `INSERT`s. + * + * The table must already exist on the database. It must have a schema + * that is compatible with the schema of this RDD; inserting the rows of + * the RDD in order via the simple statement + * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. + * @group output + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. + */ + @deprecated("Use write.jdbc()", "1.4.0") + def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { + val w = if (overwrite) write.mode(SaveMode.Overwrite) else write + w.jdbc(url, table, new Properties) + } + /** * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().parquet()`. */ + @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { if (sqlContext.conf.parquetUseDataSourceApi) { - save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path)) + write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } else { sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd } } /** - * :: Experimental :: * Creates a table from the the contents of this DataFrame. * It will use the default data source configured by spark.sql.sources.default. * This will fail if the table already exists. @@ -1318,15 +1629,14 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String): Unit = { - saveAsTable(tableName, SaveMode.ErrorIfExists) + write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table from the the contents of this DataFrame, using the default data source * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. * @@ -1338,22 +1648,14 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { - if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) { - // If table already exists and the save mode is Append, - // we will just call insertInto to append the contents of this DataFrame. - insertInto(tableName, overwrite = false) - } else { - val dataSourceName = sqlContext.conf.defaultDataSourceName - saveAsTable(tableName, dataSourceName, mode) - } + write.mode(mode).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table at the given path from the the contents of this DataFrame * based on a given data source and a set of options, * using [[SaveMode.ErrorIfExists]] as the save mode. @@ -1366,11 +1668,11 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { - saveAsTable(tableName, source, SaveMode.ErrorIfExists) + write.format(source).saveAsTable(tableName) } /** @@ -1386,15 +1688,14 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { - saveAsTable(tableName, source, mode, Map.empty[String, String]) + write.format(source).mode(mode).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table at the given path from the the contents of this DataFrame * based on a given data source, [[SaveMode]] specified by mode, and a set of options. * @@ -1406,42 +1707,20 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", + "1.4.0") def saveAsTable( tableName: String, source: String, mode: SaveMode, options: java.util.Map[String, String]): Unit = { - saveAsTable(tableName, source, mode, options.toMap) - } - - /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of - * partition columns. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * @group output - * @since 1.4.0 - */ - @Experimental - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: java.util.Map[String, String], - partitionColumns: java.util.List[String]): Unit = { - saveAsTable(tableName, source, mode, options.toMap, partitionColumns) + write.format(source).mode(mode).options(options).saveAsTable(tableName) } /** - * :: Experimental :: * (Scala-specific) * Creates a table from the the contents of this DataFrame based on a given data source, * [[SaveMode]] specified by mode, and a set of options. @@ -1454,328 +1733,123 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", + "1.4.0") def saveAsTable( tableName: String, source: String, mode: SaveMode, options: Map[String, String]): Unit = { - val cmd = - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - Array.empty[String], - mode, - options, - logicalPlan) - - sqlContext.executePlan(cmd).toRdd - } - - /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of - * partition columns. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * @group output - * @since 1.4.0 - */ - @Experimental - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: Map[String, String], - partitionColumns: Seq[String]): Unit = { - sqlContext.executePlan( - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - partitionColumns.toArray, - mode, - options, - logicalPlan)).toRdd + write.format(source).mode(mode).options(options).saveAsTable(tableName) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path, * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().save(path)`. */ - @Experimental + @deprecated("Use write.save(path)", "1.4.0") def save(path: String): Unit = { - save(path, SaveMode.ErrorIfExists) + write.save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, * using the default data source configured by spark.sql.sources.default. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. */ - @Experimental + @deprecated("Use write.mode(mode).save(path)", "1.4.0") def save(path: String, mode: SaveMode): Unit = { - val dataSourceName = sqlContext.conf.defaultDataSourceName - save(path, dataSourceName, mode) + write.mode(mode).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path based on the given data source, * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).save(path)", "1.4.0") def save(path: String, source: String): Unit = { - save(source, SaveMode.ErrorIfExists, Map("path" -> path)) + write.format(source).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path based on the given data source and * [[SaveMode]] specified by mode. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { - save(source, mode, Map("path" -> path)) + write.format(source).mode(mode).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( source: String, mode: SaveMode, options: java.util.Map[String, String]): Unit = { - save(source, mode, options.toMap) - } - - /** - * :: Experimental :: - * Saves the contents of this DataFrame to the given path based on the given data source, - * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`. - * @group output - * @since 1.4.0 - */ - @Experimental - def save( - source: String, - mode: SaveMode, - options: java.util.Map[String, String], - partitionColumns: java.util.List[String]): Unit = { - save(source, mode, options.toMap, partitionColumns) + write.format(source).mode(mode).options(options).save() } /** - * :: Experimental :: * (Scala-specific) * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( source: String, mode: SaveMode, options: Map[String, String]): Unit = { - ResolvedDataSource(sqlContext, source, Array.empty[String], mode, options, this) + write.format(source).mode(mode).options(options).save() } - /** - * :: Experimental :: - * Saves the contents of this DataFrame to the given path based on the given data source, - * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`. - * @group output - * @since 1.4.0 - */ - @Experimental - def save( - source: String, - mode: SaveMode, - options: Map[String, String], - partitionColumns: Seq[String]): Unit = { - ResolvedDataSource(sqlContext, source, partitionColumns.toArray, mode, options, this) - } /** - * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd + write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) } /** - * :: Experimental :: * Adds the rows from this RDD to the specified table. * Throws an exception if the table already exists. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append).saveAsTable(tableName)`. */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * Returns the content of the [[DataFrame]] as a RDD of JSON strings. - * @group rdd - * @since 1.3.0 - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val writer = new CharArrayWriter() - // create the Generator without separator inserted between 2 records - val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) - - new Iterator[String] { - override def hasNext: Boolean = iter.hasNext - override def next(): String = { - JacksonGenerator(rowSchema, gen)(iter.next()) - gen.flush() - - val json = writer.toString - if (hasNext) { - writer.reset() - } else { - gen.close() - } - - json - } - } - } + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + def insertInto(tableName: String): Unit = { + write.mode(SaveMode.Append).insertInto(tableName) } //////////////////////////////////////////////////////////////////////////// - // JDBC Write Support //////////////////////////////////////////////////////////////////////////// - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @since 1.3.0 - */ - def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { - createJDBCTable(url, table, allowExisting, new Properties()) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table` - * using connection properties defined in `properties`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @since 1.4.0 - */ - def createJDBCTable( - url: String, - table: String, - allowExisting: Boolean, - properties: Properties): Unit = { - val conn = DriverManager.getConnection(url, properties) - try { - if (allowExisting) { - val sql = s"DROP TABLE IF EXISTS $table" - conn.prepareStatement(sql).executeUpdate() - } - val schema = JDBCWriteDetails.schemaString(this, url) - val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - JDBCWriteDetails.saveTable(this, url, table, properties) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - * @group output - * @since 1.3.0 - */ - def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { - insertIntoJDBC(url, table, overwrite, new Properties()) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table` - * using connection properties defined in `properties`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - * @group output - * @since 1.4.0 - */ - def insertIntoJDBC( - url: String, - table: String, - overwrite: Boolean, - properties: Properties): Unit = { - if (overwrite) { - val conn = DriverManager.getConnection(url, properties) - try { - val sql = s"TRUNCATE TABLE $table" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - } - JDBCWriteDetails.saveTable(this, url, table, properties) - } + // End of deprecated methods //////////////////////////////////////////////////////////////////////////// - // for Python API //////////////////////////////////////////////////////////////////////////// - /** - * Converts a JavaRDD to a PythonRDD. - */ - protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index b87efb58d51e..2f19ec040301 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -28,5 +28,5 @@ private[sql] case class DataFrameHolder(df: DataFrame) { // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = df - def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*) + def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index b4c2daa05586..8681a56c82f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -391,7 +391,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. */ private def fillCol[T](col: StructField, replacement: T): Column = { - coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name) + coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala new file mode 100644 index 000000000000..1828ed1aab50 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -0,0 +1,289 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.util.Properties + +import org.apache.hadoop.fs.Path +import org.apache.spark.Partition + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.json.{JsonRDD, JSONRelation} +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, + * key-value stores, etc). Use [[SQLContext.read]] to access this. + * + * @since 1.4.0 + */ +@Experimental +class DataFrameReader private[sql](sqlContext: SQLContext) { + + /** + * Specifies the input data source format. + * + * @since 1.4.0 + */ + def format(source: String): DataFrameReader = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data source can + * skip the schema inference step, and thus speed up data loading. + * + * @since 1.4.0 + */ + def schema(schema: StructType): DataFrameReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * Adds an input option for the underlying data source. + * + * @since 1.4.0 + */ + def option(key: String, value: String): DataFrameReader = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds input options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: scala.collection.Map[String, String]): DataFrameReader = { + this.extraOptions ++= options + this + } + + /** + * Adds input options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: java.util.Map[String, String]): DataFrameReader = { + this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this + } + + /** + * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by + * a local or distributed file system). + * + * @since 1.4.0 + */ + def load(path: String): DataFrame = { + option("path", path).load() + } + + /** + * Loads input in as a [[DataFrame]], for data sources that don't require a path (e.g. external + * key-value stores). + * + * @since 1.4.0 + */ + def load(): DataFrame = { + val resolved = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + partitionColumns = Array.empty[String], + provider = source, + options = extraOptions.toMap) + DataFrame(sqlContext, LogicalRelation(resolved.relation)) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table and connection properties. + * + * @since 1.4.0 + */ + def jdbc(url: String, table: String, properties: Properties): DataFrame = { + jdbc(url, table, JDBCRelation.columnPartition(null), properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + * + * @since 1.4.0 + */ + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + connectionProperties: Properties): DataFrame = { + val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) + val parts = JDBCRelation.columnPartition(partitioning) + jdbc(url, table, parts, connectionProperties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table using connection properties. The `predicates` parameter gives a list + * expressions suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param predicates Condition in the where clause for each partition. + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + * @since 1.4.0 + */ + def jdbc( + url: String, + table: String, + predicates: Array[String], + connectionProperties: Properties): DataFrame = { + val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) => + JDBCPartition(part, i) : Partition + } + jdbc(url, table, parts, connectionProperties) + } + + private def jdbc( + url: String, + table: String, + parts: Array[Partition], + connectionProperties: Properties): DataFrame = { + val relation = JDBCRelation(url, table, parts, connectionProperties)(sqlContext) + sqlContext.baseRelationToDataFrame(relation) + } + + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * @param path input path + * @since 1.4.0 + */ + def json(path: String): DataFrame = format("json").load(path) + + /** + * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and + * returns the result as a [[DataFrame]]. + * + * Unless the schema is specified using [[schema]] function, this function goes through the + * input once to determine the input schema. + * + * @param jsonRDD input RDD with one JSON object per record + * @since 1.4.0 + */ + def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) + + /** + * Loads an `RDD[String]` storing JSON objects (one object per record) and + * returns the result as a [[DataFrame]]. + * + * Unless the schema is specified using [[schema]] function, this function goes through the + * input once to determine the input schema. + * + * @param jsonRDD input RDD with one JSON object per record + * @since 1.4.0 + */ + def json(jsonRDD: RDD[String]): DataFrame = { + val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble + if (sqlContext.conf.useJacksonStreamingAPI) { + sqlContext.baseRelationToDataFrame( + new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) + } else { + val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord + val appliedSchema = userSpecifiedSchema.getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(jsonRDD, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(jsonRDD, appliedSchema, columnNameOfCorruptJsonRecord) + sqlContext.internalCreateDataFrame(rowRDD, appliedSchema) + } + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def parquet(paths: String*): DataFrame = { + if (paths.isEmpty) { + sqlContext.emptyDataFrame + } else { + val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray + sqlContext.baseRelationToDataFrame( + new ParquetRelation2( + globbedPaths.map(_.toString), None, None, Map.empty[String, String])(sqlContext)) + } + } + + /** + * Returns the specified table as a [[DataFrame]]. + * + * @since 1.4.0 + */ + def table(tableName: String): DataFrame = { + DataFrame(sqlContext, sqlContext.catalog.lookupRelation(Seq(tableName))) + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sqlContext.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 5d106c1ac267..587869e57f96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -43,7 +43,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in * MLlib's Statistics. * * @param col1 the name of the column @@ -78,6 +78,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The first column of each row will be the distinct values of `col1` and the column names will * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts. + * Null elements will be replaced by "null", and back ticks will be dropped from elements if they + * exist. + * * * @param col1 The name of the first column. Distinct items will make the first item of * each row. @@ -97,6 +100,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @param support The minimum frequency for an item to be considered `frequent`. Should be greater * than 1e-4. @@ -114,6 +120,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -128,6 +137,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -143,6 +155,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala new file mode 100644 index 000000000000..5548b26cb8f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -0,0 +1,295 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.util.Properties + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} +import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} + + +/** + * :: Experimental :: + * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, + * key-value stores, etc). Use [[DataFrame.write]] to access this. + * + * @since 1.4.0 + */ +@Experimental +final class DataFrameWriter private[sql](df: DataFrame) { + + /** + * Specifies the behavior when data or table already exists. Options include: + * - `SaveMode.Overwrite`: overwrite the existing data. + * - `SaveMode.Append`: append the data. + * - `SaveMode.Ignore`: ignore the operation (i.e. no-op). + * - `SaveMode.ErrorIfExists`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ + def mode(saveMode: SaveMode): DataFrameWriter = { + this.mode = saveMode + this + } + + /** + * Specifies the behavior when data or table already exists. Options include: + * - `overwrite`: overwrite the existing data. + * - `append`: append the data. + * - `ignore`: ignore the operation (i.e. no-op). + * - `error`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ + def mode(saveMode: String): DataFrameWriter = { + this.mode = saveMode.toLowerCase match { + case "overwrite" => SaveMode.Overwrite + case "append" => SaveMode.Append + case "ignore" => SaveMode.Ignore + case "error" | "default" => SaveMode.ErrorIfExists + case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + + "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") + } + this + } + + /** + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 1.4.0 + */ + def format(source: String): DataFrameWriter = { + this.source = source + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 1.4.0 + */ + def option(key: String, value: String): DataFrameWriter = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: scala.collection.Map[String, String]): DataFrameWriter = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: java.util.Map[String, String]): DataFrameWriter = { + this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme. + * + * This is only applicable for Parquet at the moment. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataFrameWriter = { + this.partitioningColumns = Option(colNames) + this + } + + /** + * Saves the content of the [[DataFrame]] at the specified path. + * + * @since 1.4.0 + */ + def save(path: String): Unit = { + this.extraOptions += ("path" -> path) + save() + } + + /** + * Saves the content of the [[DataFrame]] as the specified table. + * + * @since 1.4.0 + */ + def save(): Unit = { + ResolvedDataSource( + df.sqlContext, + source, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df) + } + + /** + * Inserts the content of the [[DataFrame]] to the specified table. It requires that + * the schema of the [[DataFrame]] is the same as the schema of the table. + * + * Because it inserts data to an existing table, format or options will be ignored. + * + * @since 1.4.0 + */ + def insertInto(tableName: String): Unit = { + val partitions = + partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val overwrite = (mode == SaveMode.Overwrite) + df.sqlContext.executePlan(InsertIntoTable( + UnresolvedRelation(Seq(tableName)), + partitions.getOrElse(Map.empty[String, Option[String]]), + df.logicalPlan, + overwrite, + ifNotExists = false)).toRdd + } + + /** + * Saves the content of the [[DataFrame]] as the specified table. + * + * In the case the table already exists, behavior of this function depends on the + * save mode, specified by the `mode` function (default to throwing an exception). + * When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + * the same as that of the existing table. + * When `mode` is `Append`, the schema of the [[DataFrame]] need to be + * the same as that of the existing table, and format or options will be ignored. + * + * @since 1.4.0 + */ + def saveAsTable(tableName: String): Unit = { + if (df.sqlContext.catalog.tableExists(tableName :: Nil) && mode != SaveMode.Overwrite) { + mode match { + case SaveMode.Ignore => + // Do nothing + + case SaveMode.ErrorIfExists => + throw new AnalysisException(s"Table $tableName already exists.") + + case SaveMode.Append => + // If it is Append, we just ask insertInto to handle it. We will not use insertInto + // to handle saveAsTable with Overwrite because saveAsTable can change the schema of + // the table. But, insertInto with Overwrite requires the schema of data be the same + // the schema of the table. + insertInto(tableName) + } + } else { + val cmd = + CreateTableUsingAsSelect( + tableName, + source, + temporary = false, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df.logicalPlan) + df.sqlContext.executePlan(cmd).toRdd + } + } + + /** + * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the + * table already exists in the external database, behavior of this function depends on the + * save mode, specified by the `mode` function (default to throwing an exception). + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + */ + def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { + val conn = JdbcUtils.createConnection(url, connectionProperties) + + try { + var tableExists = JdbcUtils.tableExists(conn, table) + + if (mode == SaveMode.Ignore && tableExists) { + return + } + + if (mode == SaveMode.ErrorIfExists && tableExists) { + sys.error(s"Table $table already exists.") + } + + if (mode == SaveMode.Overwrite && tableExists) { + JdbcUtils.dropTable(conn, table) + tableExists = false + } + + // Create the table if the table didn't exist. + if (!tableExists) { + val schema = JDBCWriteDetails.schemaString(df, url) + val sql = s"CREATE TABLE $table ($schema)" + conn.prepareStatement(sql).executeUpdate() + } + } finally { + conn.close() + } + + JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + } + + /** + * Saves the content of the [[DataFrame]] in JSON format at the specified path. + * This is equivalent to: + * {{{ + * format("json").save(path) + * }}} + * + * @since 1.4.0 + */ + def json(path: String): Unit = format("json").save(path) + + /** + * Saves the content of the [[DataFrame]] in Parquet format at the specified path. + * This is equivalent to: + * {{{ + * format("parquet").save(path) + * }}} + * + * @since 1.4.0 + */ + def parquet(path: String): Unit = format("parquet").save(path) + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sqlContext.conf.defaultDataSourceName + + private var mode: SaveMode = SaveMode.ErrorIfExists + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var partitioningColumns: Option[Seq[String]] = None + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 1381b9f1a608..99d557b03a03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,11 +21,42 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType +/** + * Companion object for GroupedData + */ +private[sql] object GroupedData { + def apply( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType): GroupedData = { + new GroupedData(df, groupingExprs, groupType: GroupType) + } + + /** + * The Grouping Type + */ + private[sql] trait GroupType + + /** + * To indicate it's the GroupBy + */ + private[sql] object GroupByType extends GroupType + + /** + * To indicate it's the CUBE + */ + private[sql] object CubeType extends GroupType + + /** + * To indicate it's the ROLLUP + */ + private[sql] object RollupType extends GroupType +} /** * :: Experimental :: @@ -34,19 +65,41 @@ import org.apache.spark.sql.types.NumericType * @since 1.3.0 */ @Experimental -class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { +class GroupedData protected[sql]( + df: DataFrame, + groupingExprs: Seq[Expression], + private val groupType: GroupedData.GroupType) { + + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + groupingExprs ++ aggExprs + } else { + aggExprs + } - private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { - val namedGroupingExprs = groupingExprs.map { + val aliasedAgg = aggregates.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() } - DataFrame( - df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) + groupType match { + case GroupedData.GroupByType => + DataFrame( + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + case GroupedData.RollupType => + DataFrame( + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) + case GroupedData.CubeType => + DataFrame( + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + } } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) - : Seq[NamedExpression] = { + : DataFrame = { val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. @@ -63,10 +116,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) namedExpr } } - columnExprs.map { c => - val a = f(c) - Alias(a, a.prettyString)() - } + toDF(columnExprs.map(f)) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -119,10 +169,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) * @since 1.3.0 */ def agg(exprs: Map[String, String]): DataFrame = { - exprs.map { case (colName, expr) => - val a = strToExpr(expr)(df(colName).expr) - Alias(a, a.prettyString)() - }.toSeq + toDF(exprs.map { case (colName, expr) => + strToExpr(expr)(df(colName).expr) + }.toSeq) } /** @@ -175,19 +224,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - val aggExprs = (expr +: exprs).map(_.expr).map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan)) - } else { - DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) - } + toDF((expr +: exprs).map(_.expr)) } /** @@ -196,7 +233,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) * * @since 1.3.0 */ - def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")()) + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. @@ -207,9 +244,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } - + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. @@ -219,7 +256,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Max) + aggregateNumericColumns(colNames : _*)(Max) } /** @@ -231,7 +268,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } /** @@ -243,7 +280,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Min) + aggregateNumericColumns(colNames : _*)(Min) } /** @@ -255,6 +292,6 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Sum) - } + aggregateNumericColumns(colNames : _*)(Sum) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f07bb196c11e..6005d35f015a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -22,67 +22,366 @@ import java.util.Properties import scala.collection.immutable import scala.collection.JavaConversions._ +import org.apache.parquet.hadoop.ParquetOutputCommitter + import org.apache.spark.sql.catalyst.CatalystConf private[spark] object SQLConf { - val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" - val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" - val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" - val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" - val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" - val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" - val CODEGEN_ENABLED = "spark.sql.codegen" - val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" - val DIALECT = "spark.sql.dialect" - val CASE_SENSITIVE = "spark.sql.caseSensitive" - - val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" - val PARQUET_INT96_AS_TIMESTAMP = "spark.sql.parquet.int96AsTimestamp" - val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" - val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" - val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" - val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" - - val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" - - val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" - val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" + + private val sqlConfEntries = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, SQLConfEntry[_]]()) + + /** + * An entry contains all meta information for a configuration. + * + * @param key the key for the configuration + * @param defaultValue the default value for the configuration + * @param valueConverter how to convert a string to the value. It should throw an exception if the + * string does not have the required format. + * @param stringConverter how to convert a value to a string that the user can use it as a valid + * string value. It's usually `toString`. But sometimes, a custom converter + * is necessary. E.g., if T is List[String], `a, b, c` is better than + * `List(a, b, c)`. + * @param doc the document for the configuration + * @param isPublic if this configuration is public to the user. If it's `false`, this + * configuration is only used internally and we should not expose it to the user. + * @tparam T the value type + */ + private[sql] class SQLConfEntry[T] private( + val key: String, + val defaultValue: Option[T], + val valueConverter: String => T, + val stringConverter: T => String, + val doc: String, + val isPublic: Boolean) { + + def defaultValueString: String = defaultValue.map(stringConverter).getOrElse("") + + override def toString: String = { + s"SQLConfEntry(key = $key, defaultValue=$defaultValueString, doc=$doc, isPublic = $isPublic)" + } + } + + private[sql] object SQLConfEntry { + + private def apply[T]( + key: String, + defaultValue: Option[T], + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean): SQLConfEntry[T] = + sqlConfEntries.synchronized { + if (sqlConfEntries.containsKey(key)) { + throw new IllegalArgumentException(s"Duplicate SQLConfEntry. $key has been registered") + } + val entry = + new SQLConfEntry[T](key, defaultValue, valueConverter, stringConverter, doc, isPublic) + sqlConfEntries.put(key, entry) + entry + } + + def intConf( + key: String, + defaultValue: Option[Int] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Int] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toInt + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be int, but was $v") + } + }, _.toString, doc, isPublic) + + def longConf( + key: String, + defaultValue: Option[Long] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Long] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toLong + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be long, but was $v") + } + }, _.toString, doc, isPublic) + + def doubleConf( + key: String, + defaultValue: Option[Double] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Double] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toDouble + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be double, but was $v") + } + }, _.toString, doc, isPublic) + + def booleanConf( + key: String, + defaultValue: Option[Boolean] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Boolean] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException(s"$key should be boolean, but was $v") + } + }, _.toString, doc, isPublic) + + def stringConf( + key: String, + defaultValue: Option[String] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[String] = + SQLConfEntry(key, defaultValue, v => v, v => v, doc, isPublic) + + def enumConf[T]( + key: String, + valueConverter: String => T, + validValues: Set[T], + defaultValue: Option[T] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[T] = + SQLConfEntry(key, defaultValue, v => { + val _v = valueConverter(v) + if (!validValues.contains(_v)) { + throw new IllegalArgumentException( + s"The value of $key should be one of ${validValues.mkString(", ")}, but was $v") + } + _v + }, _.toString, doc, isPublic) + + def seqConf[T]( + key: String, + valueConverter: String => T, + defaultValue: Option[Seq[T]] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Seq[T]] = { + SQLConfEntry( + key, defaultValue, _.split(",").map(valueConverter), _.mkString(","), doc, isPublic) + } + + def stringSeqConf( + key: String, + defaultValue: Option[Seq[String]] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Seq[String]] = { + seqConf(key, s => s, defaultValue, doc, isPublic) + } + } + + import SQLConfEntry._ + + val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", + defaultValue = Some(true), + doc = "When set to true Spark SQL will automatically select a compression codec for each " + + "column based on statistics of the data.") + + val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize", + defaultValue = Some(10000), + doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " + + "memory utilization and compression, but risk OOMs when caching data.") + + val IN_MEMORY_PARTITION_PRUNING = + booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", + defaultValue = Some(false), + doc = "") + + val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", + defaultValue = Some(10 * 1024 * 1024), + doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " + + "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + + "Note that currently statistics are only supported for Hive Metastore tables where the " + + "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") + + val DEFAULT_SIZE_IN_BYTES = longConf("spark.sql.defaultSizeInBytes", isPublic = false) + + val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", + defaultValue = Some(200), + doc = "Configures the number of partitions to use when shuffling data for joins or " + + "aggregations.") + + val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + defaultValue = Some(true), + doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + + " a specific query. For some queries with complicated expression this option can lead to " + + "significant speed-ups. However, for simple queries this can actually slow down query " + + "execution.") + + val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", + defaultValue = Some(false), + doc = "") + + val DIALECT = stringConf("spark.sql.dialect", defaultValue = Some("sql"), doc = "") + + val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", + defaultValue = Some(true), + doc = "") + + val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema", + defaultValue = Some(true), + doc = "When true, the Parquet data source merges schemas collected from all data files, " + + "otherwise the schema is picked from the summary file or a random data file " + + "if no summary file is available.") + + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", + defaultValue = Some(false), + doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + + "Spark SQL, do not differentiate between binary data and strings when writing out the " + + "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " + + "compatibility with these systems.") + + val PARQUET_INT96_AS_TIMESTAMP = booleanConf("spark.sql.parquet.int96AsTimestamp", + defaultValue = Some(true), + doc = "Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + + "Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " + + "nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " + + "provide compatibility with these systems.") + + val PARQUET_CACHE_METADATA = booleanConf("spark.sql.parquet.cacheMetadata", + defaultValue = Some(true), + doc = "Turns on caching of Parquet schema metadata. Can speed up querying of static data.") + + val PARQUET_COMPRESSION = enumConf("spark.sql.parquet.compression.codec", + valueConverter = v => v.toLowerCase, + validValues = Set("uncompressed", "snappy", "gzip", "lzo"), + defaultValue = Some("gzip"), + doc = "Sets the compression codec use when writing Parquet files. Acceptable values include: " + + "uncompressed, snappy, gzip, lzo.") + + val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", + defaultValue = Some(false), + doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default " + + "because of a known bug in Parquet 1.6.0rc3 " + + "(PARQUET-136, https://issues.apache.org/jira/browse/PARQUET-136). However, " + + "if your table doesn't contain any nullable string or binary columns, it's still safe to " + + "turn this feature on.") + + val PARQUET_USE_DATA_SOURCE_API = booleanConf("spark.sql.parquet.useDataSourceApi", + defaultValue = Some(true), + doc = "") + + val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( + key = "spark.sql.parquet.followParquetFormatSpec", + defaultValue = Some(false), + doc = "Whether to stick to Parquet format specification when converting Parquet schema to " + + "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " + + "to compatible mode if set to false.", + isPublic = false) + + val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( + key = "spark.sql.parquet.output.committer.class", + defaultValue = Some(classOf[ParquetOutputCommitter].getName), + doc = "The output committer class used by Parquet. The specified class needs to be a " + + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + + "option must be set in Hadoop Configuration. 2. This option overrides " + + "\"spark.sql.sources.outputCommitterClass\"." + ) + + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", + defaultValue = Some(false), + doc = "") + + val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", + defaultValue = Some(true), + doc = "") + + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", + defaultValue = Some("_corrupt_record"), + doc = "") + + val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", + defaultValue = Some(5 * 60), + doc = "") // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. - val EXTERNAL_SORT = "spark.sql.planner.externalSort" - val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" + val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", + defaultValue = Some(true), + doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + + " memory.") + + val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", + defaultValue = Some(false), + doc = "") // This is only used for the thriftserver - val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" - val THRIFTSERVER_UI_STATEMENT_LIMIT = "spark.sql.thriftserver.ui.retainedStatements" - val THRIFTSERVER_UI_SESSION_LIMIT = "spark.sql.thriftserver.ui.retainedSessions" + val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", + doc = "Set a Fair Scheduler pool for a JDBC client session") + + val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", + defaultValue = Some(200), + doc = "") + + val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions", + defaultValue = Some(200), + doc = "") // This is used to set the default data source - val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" + val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default", + defaultValue = Some("org.apache.spark.sql.parquet"), + doc = "") + // This is used to control the when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has // a length restriction of 4000 characters). We will split the JSON string of a schema // to its length exceeds the threshold. - val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold" + val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold", + defaultValue = Some(4000), + doc = "") // Whether to perform partition discovery when loading external data sources. Default to true. - val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled" + val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", + defaultValue = Some(true), + doc = "") + + // Whether to perform partition column type inference. Default to true. + val PARTITION_COLUMN_TYPE_INFERENCE = + booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", + defaultValue = Some(true), + doc = "") + + // The output committer class used by HadoopFsRelation. The specified class needs to be a + // subclass of org.apache.hadoop.mapreduce.OutputCommitter. + // + // NOTE: + // + // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. + // 2. This option can be overriden by "spark.sql.parquet.output.committer.class". + val OUTPUT_COMMITTER_CLASS = + stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val DATAFRAME_EAGER_ANALYSIS = booleanConf("spark.sql.eagerAnalysis", + defaultValue = Some(true), + doc = "") // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231. - val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity" + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = + booleanConf("spark.sql.selfJoinAutoResolveAmbiguity", defaultValue = Some(true), doc = "") // Whether to retain group by columns or not in GroupedData.agg. - val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns" + val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf("spark.sql.retainGroupColumns", + defaultValue = Some(true), + doc = "") - val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2", + defaultValue = Some(true), doc = "") - val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI" + val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI", + defaultValue = Some(true), doc = "") object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -121,57 +420,54 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * Note that the choice of dialect does not affect things like what tables are available or * how query execution is performed. */ - private[spark] def dialect: String = getConf(DIALECT, "sql") + private[spark] def dialect: String = getConf(DIALECT) /** When true tables cached using the in-memory columnar caching will be compressed. */ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) + + private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt + private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) /** When true predicates will be passed to the parquet record reader when possible. */ - private[spark] def parquetFilterPushDown = - getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) /** When true uses Parquet implementation based on data source API */ - private[spark] def parquetUseDataSourceApi = - getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API) + + private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) /** When true uses verifyPartitionPath to prune the path which is not exists. */ - private[spark] def verifyPartitionPath = - getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean + private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) /** When true the planner will use the external sort, which may spill to disk. */ - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) /** * Sort merge join would sort the two side of join first, and then iterate both sides together * only once to get all matches. Using sort merge join can save a lot of memory usage compared * to HashJoin. */ - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) /** - * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode + * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster - * than interpreted evaluation, but there are significant start-up costs due to compilation. - * As a result codegen is only beneficial when queries run for a long time, or when the same - * expressions are used multiple times. - * - * Defaults to false as this feature is currently experimental. + * than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation. */ - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) /** * caseSensitive analysis true by default */ - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, "true").toBoolean + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) /** * When set to true, Spark SQL will use managed memory for certain operations. This option only @@ -179,15 +475,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * * Defaults to false as this feature is currently experimental. */ - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) /** * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0 */ - private[spark] def useJacksonStreamingAPI: Boolean = - getConf(USE_JACKSON_STREAMING_API, "true").toBoolean + private[spark] def useJacksonStreamingAPI: Boolean = getConf(USE_JACKSON_STREAMING_API) /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -196,8 +491,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ - private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, @@ -206,79 +500,122 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * in joins. */ private[spark] def defaultSizeInBytes: Long = - getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong + getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) /** * When set to true, we always treat byte arrays in Parquet files as strings. */ - private[spark] def isParquetBinaryAsString: Boolean = - getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) /** * When set to true, we always treat INT96Values in Parquet files as timestamp. */ - private[spark] def isParquetINT96AsTimestamp: Boolean = - getConf(PARQUET_INT96_AS_TIMESTAMP, "true").toBoolean + private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + + /** + * When set to true, sticks to Parquet format spec when converting Parquet schema to Spark SQL + * schema and vice versa. Otherwise, falls back to compatible mode. + */ + private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) /** * When set to true, partition pruning for in-memory columnar tables is enabled. */ - private[spark] def inMemoryPartitionPruning: Boolean = - getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) - private[spark] def columnNameOfCorruptRecord: String = - getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record") + private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) /** * Timeout in seconds for the broadcast wait time in hash join */ - private[spark] def broadcastTimeout: Int = - getConf(BROADCAST_TIMEOUT, (5 * 60).toString).toInt + private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) - private[spark] def defaultDataSourceName: String = - getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet") + private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) - private[spark] def partitionDiscoveryEnabled() = - getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean + private[spark] def partitionDiscoveryEnabled(): Boolean = + getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) + + private[spark] def partitionColumnTypeInferenceEnabled(): Boolean = + getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. - private[spark] def schemaStringLengthThreshold: Int = - getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt + private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - private[spark] def dataFrameEagerAnalysis: Boolean = - getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean + private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = - getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean + getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) + + private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) - private[spark] def dataFrameRetainGroupColumns: Boolean = - getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ def setConf(props: Properties): Unit = settings.synchronized { - props.foreach { case (k, v) => settings.put(k, v) } + props.foreach { case (k, v) => setConfString(k, v) } } - /** Set the given Spark SQL configuration property. */ - def setConf(key: String, value: String): Unit = { + /** Set the given Spark SQL configuration property using a `string` value. */ + def setConfString(key: String, value: String): Unit = { require(key != null, "key cannot be null") require(value != null, s"value cannot be null for key: $key") + val entry = sqlConfEntries.get(key) + if (entry != null) { + // Only verify configs in the SQLConf object + entry.valueConverter(value) + } settings.put(key, value) } + /** Set the given Spark SQL configuration property. */ + def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + require(entry != null, "entry cannot be null") + require(value != null, s"value cannot be null for key: ${entry.key}") + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + settings.put(entry.key, entry.stringConverter(value)) + } + /** Return the value of Spark SQL configuration property for the given key. */ - def getConf(key: String): String = { - Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key)) + def getConfString(key: String): String = { + Option(settings.get(key)). + orElse { + // Try to use the default value + Option(sqlConfEntries.get(key)).map(_.defaultValueString) + }. + getOrElse(throw new NoSuchElementException(key)) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * desired one. + */ + def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue) } /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue`. + * yet, return `defaultValue` in [[SQLConfEntry]]. */ - def getConf(key: String, defaultValue: String): String = { + def getConf[T](entry: SQLConfEntry[T]): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue). + getOrElse(throw new NoSuchElementException(entry.key)) + } + + /** + * Return the `string` value of Spark SQL configuration property for the given key. If the key is + * not set yet, return `defaultValue`. + */ + def getConfString(key: String, defaultValue: String): String = { + val entry = sqlConfEntries.get(key) + if (entry != null && defaultValue != "") { + // Only verify configs in the SQLConf object + entry.valueConverter(defaultValue) + } Option(settings.get(key)).getOrElse(defaultValue) } @@ -288,11 +625,25 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } - private[spark] def unsetConf(key: String) { + /** + * Return all the configuration definitions that have been defined in [[SQLConf]]. Each + * definition contains key, defaultValue and doc. + */ + def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { + sqlConfEntries.values.filter(_.isPublic).map { entry => + (entry.key, entry.defaultValueString, entry.doc) + }.toSeq + } + + private[spark] def unsetConf(key: String): Unit = { settings -= key } - private[spark] def clear() { + private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { + settings -= entry.key + } + + private[spark] def clear(): Unit = { settings.clear() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9fb355eb8193..e81371e7b0e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.beans.Introspector import java.util.Properties +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import scala.collection.immutable @@ -26,29 +27,23 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import com.google.common.reflect.TypeToken -import org.apache.hadoop.fs.Path - +import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution.{Filter, _} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json._ -import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import org.apache.spark.{Partition, SparkContext} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -86,13 +81,16 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def setConf(props: Properties): Unit = conf.setConf(props) + /** Set the given Spark SQL configuration property. */ + private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = conf.setConf(entry, value) + /** * Set the given Spark SQL configuration property. * * @group config * @since 1.0.0 */ - def setConf(key: String, value: String): Unit = conf.setConf(key, value) + def setConf(key: String, value: String): Unit = conf.setConfString(key, value) /** * Return the value of Spark SQL configuration property for the given key. @@ -100,7 +98,22 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group config * @since 1.0.0 */ - def getConf(key: String): String = conf.getConf(key) + def getConf(key: String): String = conf.getConfString(key) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue` in [[SQLConfEntry]]. + */ + private[sql] def getConf[T](entry: SQLConfEntry[T]): T = conf.getConf(entry) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * desired one. + */ + private[sql] def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + conf.getConf(entry, defaultValue) + } /** * Return the value of Spark SQL configuration property for the given key. If the key is not set @@ -109,7 +122,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group config * @since 1.0.0 */ - def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue) + def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue) /** * Return all the configuration properties that have been set (i.e. not the default). @@ -126,13 +139,14 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true) + protected[sql] lazy val functionRegistry: FunctionRegistry = + new OverrideFunctionRegistry(FunctionRegistry.builtin) @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUdfs :: + ExtractPythonUDFs :: sources.PreInsertCastAndRename :: Nil @@ -188,9 +202,28 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => setConf(key, value) - case _ => + { + // We extract spark sql settings from SparkContext's conf and put them to + // Spark SQL's conf. + // First, we populate the SQLConf (conf). So, we can make sure that other values using + // those settings in their construction can get the correct settings. + // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version + // and spark.sql.hive.metastore.jars to get correctly constructed. + val properties = new Properties + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) + case _ => + } + // We directly put those settings to conf to avoid of calling setConf, which may have + // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive + // get constructed. If we call setConf directly, the constructed metadataHive may have + // wrong settings, or the construction may fail. + conf.setConf(properties) + // After we have populated SQLConf, we call setConf to populate other confs in the subclass + // (e.g. hiveconf in HiveContext). + properties.foreach { + case (key, value) => setConf(key, value) + } } @transient @@ -224,7 +257,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * The following example registers a Scala closure as UDF: * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) * }}} * * The following example registers a UDF in Java: @@ -241,7 +274,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Or, to use Java 8 lambda syntax: * {{{ * sqlContext.udf().register("myUDF", - * (Integer arg1, String arg2) -> arg2 + arg1), + * (Integer arg1, String arg2) -> arg2 + arg1, * DataTypes.StringType); * }}} * @@ -304,7 +337,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args :_*)) + new ColumnName(sc.s(args : _*)) } } @@ -345,10 +378,11 @@ class SQLContext(@transient val sparkContext: SparkContext) val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => row.setInt(0, v) - row: Row + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** @@ -361,10 +395,11 @@ class SQLContext(@transient val sparkContext: SparkContext) val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => row.setLong(0, v) - row: Row + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** @@ -376,11 +411,12 @@ class SQLContext(@transient val sparkContext: SparkContext) val rows = data.mapPartitions { iter => val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => - row.setString(0, v) - row: Row + row.update(0, UTF8String.fromString(v)) + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } } @@ -396,7 +432,7 @@ class SQLContext(@transient val sparkContext: SparkContext) SparkPlan.currentContext.set(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) } @@ -472,14 +508,26 @@ class SQLContext(@transient val sparkContext: SparkContext) // schema differs from the existing schema on any field data type. val catalystRows = if (needsConversion) { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - rowRDD.map(converter(_).asInstanceOf[Row]) + rowRDD.map(converter(_).asInstanceOf[InternalRow]) } else { - rowRDD + rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} } val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) DataFrame(this, logicalPlan) } + /** + * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be + * converted to Catalyst rows. + */ + private[sql] + def internalCreateDataFrame(catalystRows: RDD[InternalRow], schema: StructType) = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied + // schema differs from the existing schema on any field data type. + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) + DataFrame(this, logicalPlan) + } + /** * :: DeveloperApi :: * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema. @@ -511,13 +559,13 @@ class SQLContext(@transient val sparkContext: SparkContext) Class.forName(className, true, Utils.getContextOrSparkClassLoader)) val extractors = localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) - + val methodsToConverts = extractors.zip(attributeSeq).map { case (e, attr) => + (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) + } iter.map { row => - new GenericRow( - extractors.zip(attributeSeq).map { case (e, attr) => - CatalystTypeConverters.convertToCatalyst(e.invoke(row), attr.dataType) - }.toArray[Any] - ) : Row + new GenericInternalRow( + methodsToConverts.map { case (e, convert) => convert(e.invoke(row)) }.toArray[Any] + ): InternalRow } } DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) @@ -536,619 +584,225 @@ class SQLContext(@transient val sparkContext: SparkContext) } /** - * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. - * It is important to make sure that the structure of every [[Row]] of the provided RDD matches - * the provided schema. Otherwise, there will be runtime exception. - * Example: + * :: Experimental :: + * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. * {{{ - * import org.apache.spark.sql._ - * import org.apache.spark.sql.types._ - * val sqlContext = new org.apache.spark.sql.SQLContext(sc) - * - * val schema = - * StructType( - * StructField("name", StringType, false) :: - * StructField("age", IntegerType, true) :: Nil) - * - * val people = - * sc.textFile("examples/src/main/resources/people.txt").map( - * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val dataFrame = sqlContext. applySchema(people, schema) - * dataFrame.printSchema - * // root - * // |-- name: string (nullable = false) - * // |-- age: integer (nullable = true) - * - * dataFrame.registerTempTable("people") - * sqlContext.sql("select name from people").collect.foreach(println) + * sqlContext.read.parquet("/path/to/file.parquet") + * sqlContext.read.schema(schema).json("/path/to/file.json") * }}} + * + * @group genericdata + * @since 1.4.0 */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } + @Experimental + def read: DataFrameReader = new DataFrameReader(this) /** - * Applies a schema to an RDD of Java Beans. + * :: Experimental :: + * Creates an external table from the given path and returns the corresponding DataFrame. + * It will use the default data source configured by spark.sql.sources.default. * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. + * @group ddl_ops + * @since 1.3.0 */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) + @Experimental + def createExternalTable(tableName: String, path: String): DataFrame = { + val dataSourceName = conf.defaultDataSourceName + createExternalTable(tableName, path, dataSourceName) } /** - * Applies a schema to an RDD of Java Beans. + * :: Experimental :: + * Creates an external table from the given path based on a data source + * and returns the corresponding DataFrame. * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. + * @group ddl_ops + * @since 1.3.0 */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) + @Experimental + def createExternalTable( + tableName: String, + path: String, + source: String): DataFrame = { + createExternalTable(tableName, source, Map("path" -> path)) } /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty - * [[DataFrame]] if no paths are passed in. + * :: Experimental :: + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. * - * @group specificdata + * @group ddl_ops * @since 1.3.0 */ - @scala.annotation.varargs - def parquetFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else if (conf.parquetUseDataSourceApi) { - val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray - baseRelationToDataFrame( - new ParquetRelation2( - globbedPaths.map(_.toString), None, None, Map.empty[String, String])(this)) - } else { - DataFrame(this, parquet.ParquetRelation( - paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) - } + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, options.toMap) } /** - * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. + * :: Experimental :: + * (Scala-specific) + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. * - * @group specificdata + * @group ddl_ops * @since 1.3.0 */ - def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = None, + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } /** * :: Experimental :: - * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[DataFrame]]. + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. * - * @group specificdata + * @group ddl_ops * @since 1.3.0 */ @Experimental - def jsonFile(path: String, schema: StructType): DataFrame = - load("json", schema, Map("path" -> path)) + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, schema, options.toMap) + } /** * :: Experimental :: - * @group specificdata + * (Scala-specific) + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops * @since 1.3.0 */ @Experimental - def jsonFile(path: String, samplingRatio: Double): DataFrame = - load("json", Map("path" -> path, "samplingRatio" -> samplingRatio.toString)) + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = Some(schema), + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @since 1.3.0 + * Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist + * only during the lifetime of this instance of SQLContext. */ - def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) - + private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), df.logicalPlan) + } /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. + * Drops the temporary table with the given table name in the catalog. If the table has been + * cached/persisted before, it's also unpersisted. * - * @group specificdata + * @param tableName the name of the table to be unregistered. + * + * @group basic * @since 1.3.0 */ - def jsonRDD(json: JavaRDD[String]): DataFrame = jsonRDD(json.rdd, 1.0) + def dropTempTable(tableName: String): Unit = { + cacheManager.tryUncacheQuery(table(tableName)) + catalog.unregisterTable(Seq(tableName)) + } /** * :: Experimental :: - * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[DataFrame]]. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from 0 to `end` (exclusive) with step value 1. * - * @group specificdata - * @since 1.3.0 + * @since 1.4.1 + * @group dataframe */ @Experimental - def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - if (conf.useJacksonStreamingAPI) { - baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this)) - } else { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - Option(schema).getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) - } - } + def range(end: Long): DataFrame = range(0, end) /** * :: Experimental :: - * Loads an JavaRDD storing JSON objects (one object per record) and applies the given - * schema, returning the result as a [[DataFrame]]. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with step value 1. * - * @group specificdata - * @since 1.3.0 + * @since 1.4.0 + * @group dataframe */ @Experimental - def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - jsonRDD(json.rdd, schema) + def range(start: Long, end: Long): DataFrame = { + createDataFrame( + sparkContext.range(start, end).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) } /** * :: Experimental :: - * Loads an RDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value, with partition number + * specified. * - * @group specificdata - * @since 1.3.0 + * @since 1.4.0 + * @group dataframe */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - if (conf.useJacksonStreamingAPI) { - baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this)) - } else { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) - } + def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { + createDataFrame( + sparkContext.range(start, end, step, numPartitions).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) } /** - * :: Experimental :: - * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. + * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. * - * @group specificdata + * @group basic * @since 1.3.0 */ - @Experimental - def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - jsonRDD(json.rdd, samplingRatio); + def sql(sqlText: String): DataFrame = { + DataFrame(this, parseSql(sqlText)) } /** - * :: Experimental :: - * Returns the dataset stored at path as a DataFrame, - * using the default data source configured by spark.sql.sources.default. + * Returns the specified table as a [[DataFrame]]. * - * @group genericdata + * @group ddl_ops * @since 1.3.0 */ - @Experimental - def load(path: String): DataFrame = { - val dataSourceName = conf.defaultDataSourceName - load(path, dataSourceName) - } + def table(tableName: String): DataFrame = + DataFrame(this, catalog.lookupRelation(Seq(tableName))) /** - * :: Experimental :: - * Returns the dataset stored at path as a DataFrame, using the given data source. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load(path: String, source: String): DataFrame = { - load(source, Map("path" -> path)) - } - - /** - * :: Experimental :: - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load(source: String, options: java.util.Map[String, String]): DataFrame = { - load(source, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load(source: String, options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, None, Array.empty[String], source, options) - DataFrame(this, LogicalRelation(resolved.relation)) - } - - /** - * :: Experimental :: - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - load(source, schema, options.toMap) - } - - /** - * :: Experimental :: - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - partitionColumns: Array[String], - options: java.util.Map[String, String]): DataFrame = { - load(source, schema, partitionColumns, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, Some(schema), Array.empty[String], source, options) - DataFrame(this, LogicalRelation(resolved.relation)) - } - - /** - * :: Experimental :: - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - partitionColumns: Array[String], - options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, Some(schema), partitionColumns, source, options) - DataFrame(this, LogicalRelation(resolved.relation)) - } - - /** - * :: Experimental :: - * Creates an external table from the given path and returns the corresponding DataFrame. - * It will use the default data source configured by spark.sql.sources.default. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable(tableName: String, path: String): DataFrame = { - val dataSourceName = conf.defaultDataSourceName - createExternalTable(tableName, path, dataSourceName) - } - - /** - * :: Experimental :: - * Creates an external table from the given path based on a data source - * and returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - path: String, - source: String): DataFrame = { - createExternalTable(tableName, source, Map("path" -> path)) - } - - /** - * :: Experimental :: - * Creates an external table from the given path based on a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) - * Creates an external table from the given path based on a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = None, - source, - temporary = false, - options, - allowExisting = false, - managedIfNoPath = false) - executePlan(cmd).toRdd - table(tableName) - } - - /** - * :: Experimental :: - * Create an external table from the given path based on a data source, a schema and - * a set of options. Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, schema, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) - * Create an external table from the given path based on a data source, a schema and - * a set of options. Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = Some(schema), - source, - temporary = false, - options, - allowExisting = false, - managedIfNoPath = false) - executePlan(cmd).toRdd - table(tableName) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. - * - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc(url: String, table: String): DataFrame = { - jdbc(url, table, JDBCRelation.columnPartition(null), new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table and connection properties. - * - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc(url: String, table: String, properties: Properties): DataFrame = { - jdbc(url, table, JDBCRelation.columnPartition(null), properties) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int): DataFrame = { - jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @param properties connection properties - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int, - properties: Properties): DataFrame = { - val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) - val parts = JDBCRelation.columnPartition(partitioning) - jdbc(url, table, parts, properties) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { - jdbc(url, table, theParts, new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table using connection properties. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - theParts: Array[String], - properties: Properties): DataFrame = { - val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) => - JDBCPartition(part, i) : Partition - } - jdbc(url, table, parts, properties) - } - - private def jdbc( - url: String, - table: String, - parts: Array[Partition], - properties: Properties): DataFrame = { - val relation = JDBCRelation(url, table, parts, properties)(this) - baseRelationToDataFrame(relation) - } - - /** - * Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist - * only during the lifetime of this instance of SQLContext. - */ - private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), df.logicalPlan) - } - - /** - * Drops the temporary table with the given table name in the catalog. If the table has been - * cached/persisted before, it's also unpersisted. - * - * @param tableName the name of the table to be unregistered. - * - * @group basic - * @since 1.3.0 - */ - def dropTempTable(tableName: String): Unit = { - cacheManager.tryUncacheQuery(table(tableName)) - catalog.unregisterTable(Seq(tableName)) - } - - /** - * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is - * used for SQL parsing can be configured with 'spark.sql.dialect'. - * - * @group basic - * @since 1.3.0 - */ - def sql(sqlText: String): DataFrame = { - DataFrame(this, parseSql(sqlText)) - } - - /** - * Returns the specified table as a [[DataFrame]]. - * - * @group ddl_ops - * @since 1.3.0 - */ - def table(tableName: String): DataFrame = - DataFrame(this, catalog.lookupRelation(Seq(tableName))) - - /** - * Returns a [[DataFrame]] containing names of existing tables in the current database. - * The returned DataFrame has two columns, tableName and isTemporary (a Boolean - * indicating if a table is a temporary one or not). + * Returns a [[DataFrame]] containing names of existing tables in the current database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). * * @group ddl_ops * @since 1.3.0 @@ -1208,7 +862,7 @@ class SQLContext(@transient val sparkContext: SparkContext) experimental.extraStrategies ++ ( DataSourceStrategy :: DDLStrategy :: - TakeOrdered :: + TakeOrderedAndProject :: HashAggregation :: LeftSemiJoin :: HashJoin :: @@ -1266,7 +920,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] val planner = new SparkPlanner @transient - protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) + protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) /** * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. @@ -1297,6 +951,11 @@ class SQLContext(@transient val sparkContext: SparkContext) tlSession.remove() } + protected[sql] def setSession(session: SQLSession): Unit = { + detachSession() + tlSession.set(session) + } + protected[sql] class SQLSession { // Note that this is a lazy val so we can override the default value in subclasses. protected[sql] lazy val conf: SQLConf = new SQLConf @@ -1328,7 +987,7 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = executedPlan.execute() + lazy val toRdd: RDD[InternalRow] = executedPlan.execute() protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -1410,7 +1069,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } val rowRdd = convertedRdd.mapPartitions { iter => - iter.map { m => new GenericRow(m): Row} + iter.map { m => new GenericInternalRow(m): InternalRow} } DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) @@ -1420,12 +1079,346 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) + val (dataType, _) = JavaTypeInference.inferDataType(beanClass) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() } } + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // Deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().parquet()`. + */ + @deprecated("Use read.parquet()", "1.4.0") + @scala.annotation.varargs + def parquetFile(paths: String*): DataFrame = { + if (paths.isEmpty) { + emptyDataFrame + } else if (conf.parquetUseDataSourceApi) { + read.parquet(paths : _*) + } else { + DataFrame(this, parquet.ParquetRelation( + paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + } + } + + /** + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String): DataFrame = { + read.json(path) + } + + /** + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, schema: StructType): DataFrame = { + read.schema(schema).json(path) + } + + /** + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(path) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an JavaRDD storing JSON objects (one object per record) and applies the given + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Returns the dataset stored at path as a DataFrame, + * using the default data source configured by spark.sql.sources.default. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().load(path)`. + */ + @deprecated("Use read.load(path)", "1.4.0") + def load(path: String): DataFrame = { + read.load(path) + } + + /** + * Returns the dataset stored at path as a DataFrame, using the given data source. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + */ + @deprecated("Use read.format(source).load(path)", "1.4.0") + def load(path: String, source: String): DataFrame = { + read.format(source).load(path) + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + */ + @deprecated("Use read.format(source).options(options).load()", "1.4.0") + def load(source: String, options: java.util.Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + */ + @deprecated("Use read.format(source).options(options).load()", "1.4.0") + def load(source: String, options: Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. + */ + @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = + { + read.format(source).schema(schema).options(options).load() + } + + /** + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. + */ + @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { + read.format(source).schema(schema).options(options).load() + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("use read.jdbc()", "1.4.0") + def jdbc(url: String, table: String): DataFrame = { + read.jdbc(url, table, new Properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("use read.jdbc()", "1.4.0") + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): DataFrame = { + read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. The theParts parameter gives a list expressions + * suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("use read.jdbc()", "1.4.0") + def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { + read.jdbc(url, table, theParts, new Properties) + } + + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // End of deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + + // Register a succesfully instantiatd context to the singleton. This should be at the end of + // the class definition so that the singleton is updated only if there is no exception in the + // construction of the instance. + SQLContext.setLastInstantiatedContext(self) } +/** + * This SQLContext object contains utility functions to create a singleton SQLContext instance, + * or to get the last created SQLContext instance. + */ +object SQLContext { + + private val INSTANTIATION_LOCK = new Object() + + /** + * Reference to the last created SQLContext. + */ + @transient private val lastInstantiatedContext = new AtomicReference[SQLContext]() + + /** + * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. + * This function can be used to create a singleton SQLContext object that can be shared across + * the JVM. + */ + def getOrCreate(sparkContext: SparkContext): SQLContext = { + INSTANTIATION_LOCK.synchronized { + if (lastInstantiatedContext.get() == null) { + new SQLContext(sparkContext) + } + } + lastInstantiatedContext.get() + } + + private[sql] def clearLastInstantiatedContext(): Unit = { + INSTANTIATION_LOCK.synchronized { + lastInstantiatedContext.set(null) + } + } + private[sql] def setLastInstantiatedContext(sqlContext: SQLContext): Unit = { + INSTANTIATION_LOCK.synchronized { + lastInstantiatedContext.set(sqlContext) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index 6b1ae81972e4..e59fa6e16290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -44,8 +44,8 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr private val pair: Parser[LogicalPlan] = (key ~ ("=".r ~> value).?).? ^^ { - case None => SetCommand(None, output) - case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output) + case None => SetCommand(None) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) } def apply(input: String): LogicalPlan = parseAll(pair, input) match { @@ -54,15 +54,15 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr } } - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val CLEAR = Keyword("CLEAR") - protected val IN = Keyword("IN") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SHOW = Keyword("SHOW") - protected val TABLE = Keyword("TABLE") - protected val TABLES = Keyword("TABLES") + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val CLEAR = Keyword("CLEAR") + protected val IN = Keyword("IN") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val SHOW = Keyword("SHOW") + protected val TABLE = Keyword("TABLE") + protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc3389c41bbf..d35d37d01719 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -20,17 +20,17 @@ package org.apache.spark.sql import java.util.{List => JList, Map => JMap} import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import org.apache.spark.{Accumulator, Logging} import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType - /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. * @@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], stringDataType: String): Unit = { @@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { envVars, pythonIncludes, pythonExec, + pythonVer, broadcastVars, accumulator, dataType, @@ -85,6 +87,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { (0 to 22).map { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" /** * Register a Scala closure of ${x} arguments as user-defined function (UDF). @@ -93,7 +96,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try($inputTypes).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) }""") @@ -112,7 +116,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { | functionRegistry.registerFunction( | name, - | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), returnType, e)) + | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) |}""".stripMargin) } */ @@ -124,7 +128,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -136,7 +141,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -148,7 +154,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -160,7 +167,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -172,7 +180,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -184,7 +193,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -196,7 +206,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -208,7 +219,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -220,7 +232,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -232,7 +245,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -244,7 +258,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -256,7 +271,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -268,7 +284,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -280,7 +297,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -292,7 +310,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -304,7 +323,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -316,7 +336,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -328,7 +349,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -340,7 +362,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -352,7 +375,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -364,7 +388,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -376,7 +401,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -388,7 +414,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -403,7 +430,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF1[_, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) } /** @@ -413,7 +440,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) } /** @@ -423,7 +450,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) } /** @@ -433,7 +460,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -443,7 +470,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -453,7 +480,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -463,7 +490,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -473,7 +500,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -483,7 +510,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -493,7 +520,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -503,7 +530,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -513,7 +540,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -523,7 +550,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -533,7 +560,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -543,7 +570,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -553,7 +580,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -563,7 +590,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -573,7 +600,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -583,7 +610,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -593,7 +620,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -603,7 +630,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -613,7 +640,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 505ab1301ec9..b14e00ab9b16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -23,7 +23,7 @@ import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -41,10 +41,13 @@ import org.apache.spark.sql.types.DataType * @since 1.3.0 */ @Experimental -case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { +case class UserDefinedFunction protected[sql] ( + f: AnyRef, + dataType: DataType, + inputTypes: Seq[DataType] = Nil) { def apply(exprs: Column*): Column = { - Column(ScalaUdf(f, dataType, exprs.map(_.expr))) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) } } @@ -58,14 +61,15 @@ private[sql] case class UserDefinedPythonFunction( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType) { /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ def apply(exprs: Column*): Column = { - val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, - accumulator, dataType, exprs.map(_.expr)) + val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, + broadcastVars, accumulator, dataType, exprs.map(_.expr)) Column(udf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 423ecdff5804..43b62f0e822f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -106,7 +106,7 @@ private[r] object SQLUtils { dfCols.map { col => colToRBytes(col) - } + } } def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { @@ -121,7 +121,7 @@ private[r] object SQLUtils { val numRows = col.length val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - + SerDe.writeInt(dos, numRows) col.map { item => @@ -139,4 +139,19 @@ private[r] object SQLUtils { case "ignore" => SaveMode.Ignore } } + + def loadDF( + sqlContext: SQLContext, + source: String, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).options(options).load() + } + + def loadDF( + sqlContext: SQLContext, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + sqlContext.read.format(source).schema(schema).options(options).load() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 64449b2659b4..931469bed634 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -71,44 +71,44 @@ private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) +private[sql] class ByteColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BYTE) private[sql] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) +private[sql] class IntColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, INT) + private[sql] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) - -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) - private[sql] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) - extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) +private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DOUBLE) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class DateColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DATE) - -private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, TIMESTAMP) - private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) with NullableColumnAccessor +private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) + extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) + private[sql] class GenericColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) with NullableColumnAccessor +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + +private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, TIMESTAMP) + private[sql] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val dup = buffer.duplicate().order(ByteOrder.nativeOrder) @@ -118,17 +118,17 @@ private[sql] object ColumnAccessor { dup.getInt() dataType match { + case BooleanType => new BooleanColumnAccessor(dup) + case ByteType => new ByteColumnAccessor(dup) + case ShortType => new ShortColumnAccessor(dup) case IntegerType => new IntColumnAccessor(dup) + case DateType => new DateColumnAccessor(dup) case LongType => new LongColumnAccessor(dup) + case TimestampType => new TimestampColumnAccessor(dup) case FloatType => new FloatColumnAccessor(dup) case DoubleType => new DoubleColumnAccessor(dup) - case BooleanType => new BooleanColumnAccessor(dup) - case ByteType => new ByteColumnAccessor(dup) - case ShortType => new ShortColumnAccessor(dup) case StringType => new StringColumnAccessor(dup) case BinaryType => new BinaryColumnAccessor(dup) - case DateType => new DateColumnAccessor(dup) - case TimestampType => new TimestampColumnAccessor(dup) case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnAccessor(dup, precision, scale) case _ => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index aa10af400c81..087c52239713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.columnar.ColumnBuilder._ import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} import org.apache.spark.sql.types._ @@ -33,7 +33,7 @@ private[sql] trait ColumnBuilder { /** * Appends `row(ordinal)` to the column builder. */ - def appendFrom(row: Row, ordinal: Int) + def appendFrom(row: InternalRow, ordinal: Int) /** * Column statistics information @@ -68,7 +68,7 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - override def appendFrom(row: Row, ordinal: Int): Unit = { + override def appendFrom(row: InternalRow, ordinal: Int): Unit = { buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) columnType.append(row, ordinal, buffer) } @@ -94,17 +94,21 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) + private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) + +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) private[sql] class FixedDecimalColumnBuilder( precision: Int, @@ -113,19 +117,15 @@ private[sql] class FixedDecimalColumnBuilder( new FixedDecimalColumnStats, FIXED_DECIMAL(precision, scale)) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +// TODO (lian) Add support for array, struct and map +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) - -// TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) - private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 @@ -151,17 +151,17 @@ private[sql] object ColumnBuilder { columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { val builder: ColumnBuilder = dataType match { + case BooleanType => new BooleanColumnBuilder + case ByteType => new ByteColumnBuilder + case ShortType => new ShortColumnBuilder case IntegerType => new IntColumnBuilder + case DateType => new DateColumnBuilder case LongType => new LongColumnBuilder + case TimestampType => new TimestampColumnBuilder case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder - case BooleanType => new BooleanColumnBuilder - case ByteType => new ByteColumnBuilder - case ShortType => new ShortColumnBuilder case StringType => new StringColumnBuilder case BinaryType => new BinaryColumnBuilder - case DateType => new DateColumnBuilder - case TimestampType => new TimestampColumnBuilder case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnBuilder(precision, scale) case _ => new GenericColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b0f983c18067..00374d1fa3ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() @@ -54,7 +53,7 @@ private[sql] sealed trait ColumnStats extends Serializable { /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: Row, ordinal: Int): Unit = { + def gatherStats(row: InternalRow, ordinal: Int): Unit = { if (row.isNullAt(ordinal)) { nullCount += 1 // 4 bytes for null position @@ -67,23 +66,23 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. */ - def collectedStatistics: Row + def collectedStatistics: InternalRow } /** * A no-op ColumnStats only used for testing purposes. */ private[sql] class NoopColumnStats extends ColumnStats { - override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) + override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: Row = Row(null, null, nullCount, count, 0L) + override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getBoolean(ordinal) @@ -93,14 +92,15 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) @@ -110,14 +110,15 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) @@ -127,48 +128,51 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class LongColumnStats extends ColumnStats { - protected var upper = Long.MinValue - protected var lower = Long.MaxValue +private[sql] class IntColumnStats extends ColumnStats { + protected var upper = Int.MinValue + protected var lower = Int.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getLong(ordinal) + val value = row.getInt(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += LONG.defaultSize + sizeInBytes += INT.defaultSize } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DoubleColumnStats extends ColumnStats { - protected var upper = Double.MinValue - protected var lower = Double.MaxValue +private[sql] class LongColumnStats extends ColumnStats { + protected var upper = Long.MinValue + protected var lower = Long.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDouble(ordinal) + val value = row.getLong(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += DOUBLE.defaultSize + sizeInBytes += LONG.defaultSize } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) @@ -178,48 +182,33 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) -} - -private[sql] class FixedDecimalColumnStats extends ColumnStats { - protected var upper: Decimal = null - protected var lower: Decimal = null - - override def gatherStats(row: Row, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Decimal] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += FIXED_DECIMAL.defaultSize - } - } - - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class IntColumnStats extends ColumnStats { - protected var upper = Int.MinValue - protected var lower = Int.MaxValue +private[sql] class DoubleColumnStats extends ColumnStats { + protected var upper = Double.MinValue + protected var lower = Double.MaxValue - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getInt(ordinal) + val value = row.getDouble(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += INT.defaultSize + sizeInBytes += DOUBLE.defaultSize } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[UTF8String] @@ -229,46 +218,52 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends IntColumnStats - -private[sql] class TimestampColumnStats extends ColumnStats { - protected var upper: Timestamp = null - protected var lower: Timestamp = null - - override def gatherStats(row: Row, ordinal: Int): Unit = { +private[sql] class BinaryColumnStats extends ColumnStats { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Timestamp] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += TIMESTAMP.defaultSize + sizeInBytes += BINARY.actualSize(row, ordinal) } } - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } -private[sql] class BinaryColumnStats extends ColumnStats { - override def gatherStats(row: Row, ordinal: Int): Unit = { +private[sql] class FixedDecimalColumnStats extends ColumnStats { + protected var upper: Decimal = null + protected var lower: Decimal = null + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += BINARY.actualSize(row, ordinal) + val value = row(ordinal).asInstanceOf[Decimal] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += FIXED_DECIMAL.defaultSize } } - override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats extends ColumnStats { - override def gatherStats(row: Row, ordinal: Int): Unit = { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { sizeInBytes += GENERIC.actualSize(row, ordinal) } } - override def collectedStatistics: Row = Row(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } + +private[sql] class DateColumnStats extends IntColumnStats + +private[sql] class TimestampColumnStats extends LongColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 20be5ca9d004..fc72360c88fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class that represents type of a column. Used to append/extract Java objects into/from @@ -63,7 +63,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this * method to avoid boxing/unboxing costs whenever possible. */ - def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { append(getField(row, ordinal), buffer) } @@ -71,13 +71,13 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable * length types such as byte arrays and strings. */ - def actualSize(row: Row, ordinal: Int): Int = defaultSize + def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs * whenever possible. */ - def getField(row: Row, ordinal: Int): JvmType + def getField(row: InternalRow, ordinal: Int): JvmType /** * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing @@ -89,7 +89,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid * boxing/unboxing costs whenever possible. */ - def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to(toOrdinal) = from(fromOrdinal) } @@ -118,7 +118,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { buffer.putInt(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putInt(row.getInt(ordinal)) } @@ -134,9 +134,9 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { row.setInt(ordinal, value) } - override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal) + override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setInt(toOrdinal, from.getInt(fromOrdinal)) } } @@ -146,7 +146,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { buffer.putLong(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putLong(row.getLong(ordinal)) } @@ -162,9 +162,9 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { row.setLong(ordinal, value) } - override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal) + override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setLong(toOrdinal, from.getLong(fromOrdinal)) } } @@ -174,7 +174,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { buffer.putFloat(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putFloat(row.getFloat(ordinal)) } @@ -190,9 +190,9 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { row.setFloat(ordinal, value) } - override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal) + override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) } } @@ -202,7 +202,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { buffer.putDouble(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putDouble(row.getDouble(ordinal)) } @@ -218,9 +218,9 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { row.setDouble(ordinal, value) } - override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal) + override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) } } @@ -230,7 +230,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (v) 1: Byte else 0: Byte) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } @@ -244,9 +244,9 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { row.setBoolean(ordinal, value) } - override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal) + override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) } } @@ -256,7 +256,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { buffer.put(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(row.getByte(ordinal)) } @@ -272,9 +272,9 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { row.setByte(ordinal, value) } - override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal) + override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setByte(toOrdinal, from.getByte(fromOrdinal)) } } @@ -284,7 +284,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { buffer.putShort(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putShort(row.getShort(ordinal)) } @@ -300,15 +300,15 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { row.setShort(ordinal, value) } - override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal) + override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setShort(toOrdinal, from.getShort(fromOrdinal)) } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(row: Row, ordinal: Int): Int = { + override def actualSize(row: InternalRow, ordinal: Int): Int = { row.getString(ordinal).getBytes("utf-8").length + 4 } @@ -321,18 +321,18 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - UTF8String(stringBytes) + UTF8String.fromBytes(stringBytes) } override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): UTF8String = { + override def getField(row: InternalRow, ordinal: Int): UTF8String = { row(ordinal).asInstanceOf[UTF8String] } - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.update(toOrdinal, from(fromOrdinal)) } } @@ -346,7 +346,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { buffer.putInt(v) } - override def getField(row: Row, ordinal: Int): Int = { + override def getField(row: InternalRow, ordinal: Int): Int = { row(ordinal).asInstanceOf[Int] } @@ -355,22 +355,20 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { } } -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { - override def extract(buffer: ByteBuffer): Timestamp = { - val timestamp = new Timestamp(buffer.getLong()) - timestamp.setNanos(buffer.getInt()) - timestamp +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { + override def extract(buffer: ByteBuffer): Long = { + buffer.getLong } - override def append(v: Timestamp, buffer: ByteBuffer): Unit = { - buffer.putLong(v.getTime).putInt(v.getNanos) + override def append(v: Long, buffer: ByteBuffer): Unit = { + buffer.putLong(v) } - override def getField(row: Row, ordinal: Int): Timestamp = { - row(ordinal).asInstanceOf[Timestamp] + override def getField(row: InternalRow, ordinal: Int): Long = { + row(ordinal).asInstanceOf[Long] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row(ordinal) = value } } @@ -389,7 +387,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) buffer.putLong(v.toUnscaledLong) } - override def getField(row: Row, ordinal: Int): Decimal = { + override def getField(row: InternalRow, ordinal: Int): Decimal = { row(ordinal).asInstanceOf[Decimal] } @@ -407,7 +405,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(row: Row, ordinal: Int): Int = { + override def actualSize(row: InternalRow, ordinal: Int): Int = { getField(row, ordinal).length + 4 } @@ -428,7 +426,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) row(ordinal) = value } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { row(ordinal).asInstanceOf[Array[Byte]] } } @@ -441,7 +439,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { SparkSqlSerializer.serialize(row(ordinal)) } } @@ -449,17 +447,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_, _] = { dataType match { + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT case IntegerType => INT + case DateType => DATE case LongType => LONG + case TimestampType => TIMESTAMP case FloatType => FLOAT case DoubleType => DOUBLE - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT case StringType => STRING case BinaryType => BINARY - case DateType => DATE - case TimestampType => TIMESTAMP case DecimalType.Fixed(precision, scale) if precision < 19 => FIXED_DECIMAL(precision, scale) case _ => GENERIC diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 0ded1cce6839..cb1fd4947fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,21 +19,16 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.apache.spark.{Accumulable, Accumulator, Accumulators} -import org.apache.spark.sql.catalyst.expressions - import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.storage.StorageLevel +import org.apache.spark.{Accumulable, Accumulator, Accumulators} private[sql] object InMemoryRelation { def apply( @@ -45,7 +40,7 @@ private[sql] object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() } -private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row) +private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -56,12 +51,12 @@ private[sql] case class InMemoryRelation( tableName: Option[String])( private var _cachedColumnBuffers: RDD[CachedBatch] = null, private var _statistics: Statistics = null, - private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null) + private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { - private val batchStats: Accumulable[ArrayBuffer[Row], Row] = + private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = if (_batchStats == null) { - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow]) } else { _batchStats } @@ -151,7 +146,8 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - val stats = Row.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*) + val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) + .flatMap(_.toSeq)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -236,7 +232,7 @@ private[sql] case class InMemoryColumnarTableScan( case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l - case IsNull(a: Attribute) => statsFor(a).nullCount > 0 + case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } @@ -267,7 +263,7 @@ private[sql] case class InMemoryColumnarTableScan( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { if (enableAccumulators) { readPartitions.setValue(0) readBatches.setValue(0) @@ -296,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( val nextRow = new SpecificMutableRow(requestedColumnDataTypes) - def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[Row] = { + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]): Iterator[InternalRow] = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => @@ -306,15 +302,15 @@ private[sql] case class InMemoryColumnarTableScan( } // Extract rows via column accessors - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val rowLen = nextRow.length - override def next(): Row = { + override def next(): InternalRow = { var i = 0 while (i < rowLen) { columnAccessors(i).extractTo(nextRow, i) i += 1 } - nextRow + if (attributes.isEmpty) InternalRow.empty else nextRow } override def hasNext: Boolean = columnAccessors(0).hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index f1f494ac26d0..ba47bc783f31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow /** * A stackable trait used for building byte buffer for a column containing null values. Memory @@ -52,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { super.initialize(initialSize, columnName, useCompression) } - abstract override def appendFrom(row: Row, ordinal: Int): Unit = { + abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 8e2a1af6dae7..39b21ddb47ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType @@ -66,7 +66,7 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType] encoder.compressionRatio < 0.8 } - private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { var i = 0 while (i < compressionEncoders.length) { compressionEncoders(i).gatherCompressibilityStats(row, ordinal) @@ -74,7 +74,7 @@ private[sql] trait CompressibleColumnBuilder[T <: AtomicType] } } - abstract override def appendFrom(row: Row, ordinal: Int): Unit = { + abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { super.appendFrom(row, ordinal) if (!row.isNullAt(ordinal)) { gatherCompressibilityStats(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 17c2d9b11118..4eaec6d853d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType private[sql] trait Encoder[T <: AtomicType] { - def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} + def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 534ae90ddbc8..5abc1259a19a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -22,8 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.runtimeMirror - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.columnar._ import org.apache.spark.sql.types._ @@ -96,7 +95,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { override def compressedSize: Int = _compressedSize - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = columnType.getField(row, ordinal) val actualSize = columnType.actualSize(row, ordinal) _uncompressedSize += actualSize @@ -217,7 +216,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // to store dictionary element count. private var dictionarySize = 4 - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = columnType.getField(row, ordinal) if (!overflow) { @@ -310,7 +309,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { _uncompressedSize += BOOLEAN.defaultSize } @@ -404,7 +403,7 @@ private[sql] case object IntDelta extends CompressionScheme { private var prevValue: Int = _ - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = row.getInt(ordinal) val delta = value - prevValue @@ -484,7 +483,7 @@ private[sql] case object LongDelta extends CompressionScheme { private var prevValue: Long = _ - override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { val value = row.getLong(ordinal) val delta = value - prevValue diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 8d16749697aa..6e8a5ef18ab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -20,12 +20,10 @@ package org.apache.spark.sql.execution import java.util.HashMap import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.SQLContext /** * :: DeveloperApi :: @@ -121,11 +119,11 @@ case class Aggregate( } } - protected override def doExecute(): RDD[Row] = attachTree(this, "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() - var currentRow: Row = null + var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() var i = 0 @@ -147,10 +145,10 @@ case class Aggregate( } } else { child.execute().mapPartitions { iter => - val hashTable = new HashMap[Row, Array[AggregateFunction]] + val hashTable = new HashMap[InternalRow, Array[AggregateFunction]] val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) - var currentRow: Row = null + var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() val currentGroup = groupingProjection(currentRow) @@ -167,7 +165,7 @@ case class Aggregate( } } - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = @@ -177,7 +175,7 @@ case class Aggregate( override final def hasNext: Boolean = hashTableIter.hasNext - override final def next(): Row = { + override final def next(): InternalRow = { val currentEntry = hashTableIter.next() val currentGroup = currentEntry.getKey val currentBuffer = currentEntry.getValue diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 5fcc48a67948..a4b38d364d54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -103,7 +103,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - query.queryExecution.executedPlan, + sqlContext.executePlan(query.logicalPlan).executedPlan, tableName)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3e46596ecf6a..e054c1d144e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,29 +17,20 @@ package org.apache.spark.sql.execution -import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.util.MutablePair - -object Exchange { - /** - * Returns true when the ordering expressions are a subset of the key. - * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. - */ - def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { - desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) - } -} +import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} /** * :: DeveloperApi :: @@ -91,11 +82,7 @@ case class Exchange( shuffleManager.isInstanceOf[UnsafeShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) - if (newOrdering.nonEmpty) { - // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`, - // which requires a defensive copy. - true - } else if (sortBasedShuffleOn) { + if (sortBasedShuffleOn) { val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { // If we're using the original SortShuffleManager and the number of output partitions is @@ -106,8 +93,11 @@ case class Exchange( } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting // them. This optimization is guarded by a feature-flag and is only applied in cases where - // shuffle dependency does not specify an ordering and the record serializer has certain - // properties. If this optimization is enabled, we can safely avoid the copy. + // shuffle dependency does not specify an aggregator or ordering and the record serializer + // has certain properties. If this optimization is enabled, we can safely avoid the copy. + // + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only + // need to check whether the optimization is enabled and supported by our serializer. // // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false @@ -118,23 +108,12 @@ case class Exchange( // both cases, we must copy. true } - } else { + } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { // We're using hash-based shuffle, so we don't need to copy. false - } - } - - private val keyOrdering = { - if (newOrdering.nonEmpty) { - val key = newPartitioning.keyExpressions - val boundOrdering = newOrdering.map { o => - val ordinal = key.indexOf(o.child) - if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") - o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) - } - new RowOrdering(boundOrdering) } else { - null // Ordering will not be used + // Catch-all case to safely handle any future ShuffleManager implementations. + true } } @@ -143,7 +122,6 @@ case class Exchange( private def getSerializer( keySchema: Array[DataType], valueSchema: Array[DataType], - hasKeyOrdering: Boolean, numPartitions: Int): Serializer = { // It is true when there is no field that needs to be write out. // For now, we will not use SparkSqlSerializer2 when noField is true. @@ -159,7 +137,7 @@ case class Exchange( val serializer = if (useSqlSerializer2) { logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering) + new SparkSqlSerializer2(keySchema, valueSchema) } else { logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) @@ -168,12 +146,12 @@ case class Exchange( serializer } - protected override def doExecute(): RDD[Row] = attachTree(this , "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => val keySchema = expressions.map(_.dataType).toArray val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions) + val serializer = getSerializer(keySchema, valueSchema, numPartitions) val part = new HashPartitioner(numPartitions) val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) { @@ -184,27 +162,24 @@ case class Exchange( } else { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() - val mutablePair = new MutablePair[Row, Row]() + val mutablePair = new MutablePair[InternalRow, InternalRow]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } } - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - if (newOrdering.nonEmpty) { - shuffled.setKeyOrdering(keyOrdering) - } + val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) shuffled.setSerializer(serializer) shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => val keySchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(keySchema, null, newOrdering.nonEmpty, numPartitions) + val serializer = getSerializer(keySchema, null, numPartitions) val childRdd = child.execute() val part: Partitioner = { // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. val rddForSampling = childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[Row, Null]() + val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } // TODO: RangePartitioner should take an Ordering. @@ -216,32 +191,31 @@ case class Exchange( childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { childRdd.mapPartitions { iter => - val mutablePair = new MutablePair[Row, Null]() + val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row, null)) } } - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - if (newOrdering.nonEmpty) { - shuffled.setKeyOrdering(keyOrdering) - } + val shuffled = new ShuffledRDD[InternalRow, Null, Null](rdd, part) shuffled.setSerializer(serializer) shuffled.map(_._1) case SinglePartition => val valueSchema = child.output.map(_.dataType).toArray - val serializer = getSerializer(null, valueSchema, hasKeyOrdering = false, 1) + val serializer = getSerializer(null, valueSchema, numPartitions = 1) val partitioner = new HashPartitioner(1) val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) { - child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) } + child.execute().mapPartitions { + iter => iter.map(r => (null, r.copy())) + } } else { child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Null, Row]() + val mutablePair = new MutablePair[Null, InternalRow]() iter.map(r => mutablePair.update(null, r)) } } - val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) + val shuffled = new ShuffledRDD[Null, InternalRow, InternalRow](rdd, partitioner) shuffled.setSerializer(serializer) shuffled.map(_._2) @@ -296,7 +270,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ .sliding(2) .map { case Seq(a) => true - case Seq(a,b) => a compatibleWith b + case Seq(a, b) => a.compatibleWith(b) }.exists(!_) // Adds Exchange or Sort operators as required @@ -306,29 +280,24 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering val needsShuffle = child.outputPartitioning != partitioning - val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) - if (needSort && needsShuffle && canSortWithShuffle) { - Exchange(partitioning, rowOrdering, child) + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) } else { - val withShuffle = if (needsShuffle) { - Exchange(partitioning, Nil, child) - } else { - child - } + child + } - val withSort = if (needSort) { - if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, withShuffle) - } else { - Sort(rowOrdering, global = false, withShuffle) - } + val withSort = if (needSort) { + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) } else { - withShuffle + Sort(rowOrdering, global = false, withShuffle) } - - withSort + } else { + withShuffle } + + withSort } if (meetsRequirements && compatible && !needsAnySort) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index a500269f3cdc..da27a753a710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} /** @@ -31,26 +31,19 @@ import org.apache.spark.sql.{Row, SQLContext} */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r.productElement(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r.productElement(i)) + i += 1 } + + mutableRow } } } @@ -58,33 +51,28 @@ object RDDConversions { /** * Convert the objects inside Row into the types Catalyst expected. */ - def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[InternalRow] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r(i)) + i += 1 } + + mutableRow } } } } /** Logical plan node for scanning data from an RDD. */ -private[sql] case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) +private[sql] case class LogicalRDD( + output: Seq[Attribute], + rdd: RDD[InternalRow])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { override def children: Seq[LogicalPlan] = Nil @@ -105,13 +93,15 @@ private[sql] case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlCon } /** Physical plan node for scanning data from an RDD. */ -private[sql] case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - protected override def doExecute(): RDD[Row] = rdd +private[sql] case class PhysicalRDD( + output: Seq[Attribute], + rdd: RDD[InternalRow]) extends LeafNode { + protected override def doExecute(): RDD[InternalRow] = rdd } /** Logical plan node for scanning data from a local collection. */ private[sql] -case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[Row])(sqlContext: SQLContext) +case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[InternalRow])(sqlContext: SQLContext) extends LogicalPlan with MultiInstanceRelation { override def children: Seq[LogicalPlan] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index f16ca36909fa..42a0c1be4f69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} /** * Apply the all of the GroupExpressions to every input row, hence we will get @@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit */ @DeveloperApi case class Expand( - projections: Seq[GroupExpression], + projections: Seq[Seq[Expression]], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -43,22 +42,22 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - protected override def doExecute(): RDD[Row] = attachTree(this, "execute") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => // TODO Move out projection objects creation and transfer to // workers via closure. However we can't assume the Projection // is serializable because of the code gen, so we have to // create the projections within each of the partition processing. - val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray + val groups = projections.map(ee => newProjection(ee, child.output)).toArray - new Iterator[Row] { - private[this] var result: Row = _ + new Iterator[InternalRow] { + private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state - private[this] var input: Row = _ + private[this] var input: InternalRow = _ override final def hasNext: Boolean = (-1 < idx && idx < groups.length) || iter.hasNext - override final def next(): Row = { + override final def next(): InternalRow = { if (idx <= 0) { // in the initial (-1) or beginning(0) of a new input row, fetch the next input tuple input = iter.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index dd02c1f4573b..c1665f78a960 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._ * For lazy computing, be sure the generator.terminate() called in the very last * TODO reusing the CompletionIterator? */ -private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row]) - extends Iterator[Row] { +private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow]) + extends Iterator[InternalRow] { lazy val results = func().toIterator override def hasNext: Boolean = results.hasNext - override def next(): Row = results.next() + override def next(): InternalRow = results.next() } /** @@ -58,11 +58,11 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { child.execute().mapPartitions { iter => - val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) + val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow iter.flatMap { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 2ec7d4fbc92d..44930f82b53a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -66,7 +66,7 @@ case class GeneratedAggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => a.collect { case agg: AggregateExpression => agg} } @@ -118,7 +118,7 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case cs @ CombineSum(expr) => - val calcType = expr.dataType + val calcType = expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited @@ -129,7 +129,7 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", calcType, nullable = true)() val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its @@ -138,15 +138,15 @@ case class GeneratedAggregate( case UnscaledValue(e) => e case _ => expr } - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present val updateFunction = If( IsNotNull(actualExpr), Coalesce( Add( - Coalesce(currentSum :: zero :: Nil), + Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: zero :: Nil), currentSum) - + val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -155,7 +155,7 @@ case class GeneratedAggregate( } AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) @@ -214,18 +214,18 @@ case class GeneratedAggregate( }.toMap val namedGroups = groupingExpressions.zipWithIndex.map { - case (ne: NamedExpression, _) => (ne, ne) - case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + case (ne: NamedExpression, _) => (ne, ne.toAttribute) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute) } - val groupMap: Map[Expression, Attribute] = - namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap - // The set of expressions that produce the final output given the aggregation buffer and the // grouping expressions. val resultExpressions = aggregateExpressions.map(_.transform { case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) - case e: Expression if groupMap.contains(e) => groupMap(e) + case e: Expression => + namedGroups.collectFirst { + case (expr, attr) if expr semanticEquals e => attr + }.getOrElse(e) }) val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) @@ -238,11 +238,6 @@ case class GeneratedAggregate( StructType(fields) } - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) - } - child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -265,7 +260,7 @@ case class GeneratedAggregate( val resultProjectionBuilder = newMutableProjection( resultExpressions, - (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + namedGroups.map(_._2) ++ computationSchema) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") val joinedRow = new JoinedRow3 @@ -273,7 +268,7 @@ case class GeneratedAggregate( if (groupingExpressions.isEmpty) { // TODO: Codegening anything other than the updateProjection is probably over kill. val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - var currentRow: Row = null + var currentRow: InternalRow = null updateProjection.target(buffer) while (iter.hasNext) { @@ -283,31 +278,31 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled && schemaSupportsUnsafe) { + } else if (unsafeEnabled) { log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, + newAggregationBuffer, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggregationBufferSchema), TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity false // disable tracking of performance metrics ) while (iter.hasNext) { - val currentRow: Row = iter.next() - val groupKey: Row = groupProjection(currentRow) + val currentRow: InternalRow = iter.next() + val groupKey: InternalRow = groupProjection(currentRow) val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) } - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() def hasNext: Boolean = mapIterator.hasNext - def next(): Row = { + def next(): InternalRow = { val entry = mapIterator.next() val result = resultProjection(joinedRow(entry.key, entry.value)) if (hasNext) { @@ -323,12 +318,9 @@ case class GeneratedAggregate( } } } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } - val buffers = new java.util.HashMap[Row, MutableRow]() + val buffers = new java.util.HashMap[InternalRow, MutableRow]() - var currentRow: Row = null + var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() val currentGroup = groupProjection(currentRow) @@ -342,13 +334,13 @@ case class GeneratedAggregate( updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) } - new Iterator[Row] { + new Iterator[InternalRow] { private[this] val resultIterator = buffers.entrySet.iterator() private[this] val resultProjection = resultProjectionBuilder() def hasNext: Boolean = resultIterator.hasNext - def next(): Row = { + def next(): InternalRow = { val currentGroup = resultIterator.next() resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 03bee80ad7f3..cd341180b610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,18 +19,20 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.Attribute /** * Physical plan node for scanning data from a local collection. */ -private[sql] case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNode { +private[sql] case class LocalTableScan( + output: Seq[Attribute], + rows: Seq[InternalRow]) extends LeafNode { private lazy val rdd = sqlContext.sparkContext.parallelize(rows) - protected override def doExecute(): RDD[Row] = rdd + protected override def doExecute(): RDD[InternalRow] = rdd override def executeCollect(): Array[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 435ac011178d..7739a9f949c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ @@ -79,11 +80,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) /** - * Returns the result of this query as an RDD[Row] by delegating to doExecute + * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute * after adding query plan information to created RDDs for visualization. * Concrete implementations of SparkPlan should override doExecute instead. */ - final def execute(): RDD[Row] = { + final def execute(): RDD[InternalRow] = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { doExecute() } @@ -91,9 +92,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Overridden by concrete implementations of SparkPlan. - * Produces the result of the query as an RDD[Row] + * Produces the result of the query as an RDD[InternalRow] */ - protected def doExecute(): RDD[Row] + protected def doExecute(): RDD[InternalRow] /** * Runs this query returning the result as an array. @@ -117,7 +118,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val childRDD = execute().map(_.copy()) - val buf = new ArrayBuffer[Row] + val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length var partsScanned = 0 while (buf.size < n && partsScanned < totalParts) { @@ -140,7 +141,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val sc = sqlContext.sparkContext val res = - sc.runJob(childRDD, (it: Iterator[Row]) => it.take(left).toArray, p, allowLocal = false) + sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p, + allowLocal = false) res.foreach(buf ++= _.take(n - buf.size)) partsScanned += numPartsToTry @@ -175,7 +177,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( - expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { if (codegenEnabled) { GeneratePredicate.generate(expression, inputSchema) } else { @@ -183,7 +185,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } - protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + protected def newOrdering( + order: Seq[SortOrder], + inputSchema: Seq[Attribute]): Ordering[InternalRow] = { if (codegenEnabled) { GenerateOrdering.generate(order, inputSchema) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index eea15aff5dbc..b19ad4f1c563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -20,22 +20,20 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer import java.util.{HashMap => JavaHashMap} -import org.apache.spark.sql.types.Decimal - import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} -import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool -import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} -import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.util.collection.OpenHashSet -import org.apache.spark.util.MutablePair - +import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.OpenHashSet +import org.apache.spark.{SparkConf, SparkEnv} private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { @@ -43,6 +41,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], new HyperLogLogSerializer) @@ -139,7 +138,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { val iterator = hs.iterator while(iterator.hasNext) { val row = iterator.next() - rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values) + rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values) } } @@ -150,7 +149,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { var i = 0 while (i < numItems) { val row = - new GenericRow(rowSerializer.read( + new GenericInternalRow(rowSerializer.read( kryo, input, classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 256d527d7b63..056d435eecd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -20,15 +20,16 @@ package org.apache.spark.sql.execution import java.io._ import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.ClassTag -import org.apache.spark.serializer._ import org.apache.spark.Logging +import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in @@ -86,7 +87,6 @@ private[sql] class Serializer2SerializationStream( private[sql] class Serializer2DeserializationStream( keySchema: Array[DataType], valueSchema: Array[DataType], - hasKeyOrdering: Boolean, in: InputStream) extends DeserializationStream with Logging { @@ -96,14 +96,9 @@ private[sql] class Serializer2DeserializationStream( if (schema == null) { () => null } else { - if (hasKeyOrdering) { - // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row. - () => new GenericMutableRow(schema.length) - } else { - // It is safe to reuse the mutable row. - val mutableRow = new SpecificMutableRow(schema) - () => mutableRow - } + // It is safe to reuse the mutable row. + val mutableRow = new SpecificMutableRow(schema) + () => mutableRow } } @@ -133,8 +128,7 @@ private[sql] class Serializer2DeserializationStream( private[sql] class SparkSqlSerializer2Instance( keySchema: Array[DataType], - valueSchema: Array[DataType], - hasKeyOrdering: Boolean) + valueSchema: Array[DataType]) extends SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer = @@ -151,7 +145,7 @@ private[sql] class SparkSqlSerializer2Instance( } def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s) + new Serializer2DeserializationStream(keySchema, valueSchema, s) } } @@ -164,14 +158,13 @@ private[sql] class SparkSqlSerializer2Instance( */ private[sql] class SparkSqlSerializer2( keySchema: Array[DataType], - valueSchema: Array[DataType], - hasKeyOrdering: Boolean) + valueSchema: Array[DataType]) extends Serializer with Logging with Serializable{ def newInstance(): SerializerInstance = - new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering) + new SparkSqlSerializer2Instance(keySchema, valueSchema) override def supportsRelocationOfSerializedObjects: Boolean = { // SparkSqlSerializer2 is stateless and writes no stream headers @@ -244,7 +237,7 @@ private[sql] object SparkSqlSerializer2 { out.writeShort(row.getShort(i)) } - case IntegerType => + case IntegerType | DateType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -252,7 +245,7 @@ private[sql] object SparkSqlSerializer2 { out.writeInt(row.getInt(i)) } - case LongType => + case LongType | TimestampType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -276,59 +269,39 @@ private[sql] object SparkSqlSerializer2 { out.writeDouble(row.getDouble(i)) } - case decimal: DecimalType => + case StringType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.apply(i).asInstanceOf[Decimal] - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + val bytes = row.getAs[UTF8String](i).getBytes out.writeInt(bytes.length) out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) } - case DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getAs[Int](i)) - } - - case TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val timestamp = row.getAs[java.sql.Timestamp](i) - val time = timestamp.getTime - val nanos = timestamp.getNanos - out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. - out.writeInt(nanos) // Write the nanoseconds part. - } - - case StringType => + case BinaryType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[UTF8String](i).getBytes + val bytes = row.getAs[Array[Byte]](i) out.writeInt(bytes.length) out.write(bytes) } - case BinaryType => + case decimal: DecimalType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[Array[Byte]](i) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray out.writeInt(bytes.length) out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) } } i += 1 @@ -341,7 +314,7 @@ private[sql] object SparkSqlSerializer2 { */ def createDeserializationFunction( schema: Array[DataType], - in: DataInputStream): (MutableRow) => Row = { + in: DataInputStream): (MutableRow) => InternalRow = { if (schema == null) { (mutableRow: MutableRow) => null } else { @@ -375,14 +348,14 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setShort(i, in.readShort()) } - case IntegerType => + case IntegerType | DateType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { mutableRow.setInt(i, in.readInt()) } - case LongType => + case LongType | TimestampType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { @@ -403,57 +376,39 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setDouble(i, in.readDouble()) } - case decimal: DecimalType => + case StringType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) - } - - case DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readInt()) + mutableRow.update(i, UTF8String.fromBytes(bytes)) } - case TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val time = in.readLong() // Read the milliseconds value. - val nanos = in.readInt() // Read the nanoseconds part. - val timestamp = new Timestamp(time) - timestamp.setNanos(nanos) - mutableRow.update(i, timestamp) - } - - case StringType => + case BinaryType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, UTF8String(bytes)) + mutableRow.update(i, bytes) } - case BinaryType => + case decimal: DecimalType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { + // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) } } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index af0029cb84f9..32044989044a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} @@ -52,6 +52,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Matches a plan whose output should be small enough to be used in broadcast join. + */ + object CanBroadcast { + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { + case BroadcastHint(p) => Some(p) + case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) + case _ => None + } + } + /** * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be * evaluated by matching hash keys. @@ -80,15 +92,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) // If the sort merge join option is set, we want to use sort merge join prior to hashjoin // for now let's support inner join first, then add outer join @@ -109,8 +117,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + case ExtractEquiJoinKeys( + LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys( + RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + joins.BroadcastHashOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.HashOuterJoin( + joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -202,13 +220,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - protected lazy val singleRowRdd = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) - object TakeOrdered extends Strategy { + object TakeOrderedAndProject extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrdered(limit, order, planLater(child)) :: Nil + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil case _ => Nil } } @@ -243,8 +264,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case (predicate, None) => predicate // Filter needs to be applied above when it contains partitioning // columns - case (predicate, _) if(!predicate.references.map(_.name).toSet - .intersect (partitionColNames).isEmpty) => predicate + case (predicate, _) + if !predicate.references.map(_.name).toSet.intersect(partitionColNames).isEmpty => + predicate } } } else { @@ -270,7 +292,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, identity[Seq[Expression]], // All filters still need to be evaluated. - InMemoryColumnarTableScan(_, filters, mem)) :: Nil + InMemoryColumnarTableScan(_, filters, mem)) :: Nil case _ => Nil } } @@ -283,8 +305,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.Distinct(child) => - execution.Distinct(partial = false, - execution.Distinct(partial = true, planLater(child))) :: Nil + throw new IllegalStateException( + "logical distinct operator should have been replaced by aggregate in the optimizer") case logical.Repartition(numPartitions, shuffle, child) => execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => @@ -299,8 +321,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case logical.Expand(projections, output, child) => - execution.Expand(projections, output, planLater(child)) :: Nil + case e @ logical.Expand(_, _, _, child) => + execution.Expand(e.projections, e.output, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Window(projectList, windowExpressions, spec, child) => @@ -328,6 +350,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil + case BroadcastHint(child) => apply(child) case _ => Nil } } @@ -354,10 +377,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") - case LogicalDescribeCommand(table, isExtended) => + case describe @ LogicalDescribeCommand(table, isExtended) => val resultPlan = self.sqlContext.executePlan(table).executedPlan ExecutedCommand( - RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil + RunnableDescribeCommand(resultPlan, describe.output, isExtended)) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index c4327ce262ac..fd6f1d7ae125 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.execution import java.util import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.util.collection.CompactBuffer /** @@ -112,16 +111,16 @@ case class Window( } } - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - new Iterator[Row] { + new Iterator[InternalRow] { // Although input rows are grouped based on windowSpec.partitionSpec, we need to // know when we have a new partition. // This is to manually construct an ordering that can be used to compare rows. // TODO: We may want to have a newOrdering that takes BoundReferences. // So, we can take advantave of code gen. - private val partitionOrdering: Ordering[Row] = + private val partitionOrdering: Ordering[InternalRow] = RowOrdering.forSchema(windowSpec.partitionSpec.map(_.dataType)) // This is used to project expressions for the partition specification. @@ -137,13 +136,13 @@ case class Window( // The number of buffered rows in the inputRowBuffer (the size of the current partition). var partitionSize: Int = 0 // The buffer used to buffer rows in a partition. - var inputRowBuffer: CompactBuffer[Row] = _ + var inputRowBuffer: CompactBuffer[InternalRow] = _ // The partition key of the current partition. - var currentPartitionKey: Row = _ + var currentPartitionKey: InternalRow = _ // The partition key of next partition. - var nextPartitionKey: Row = _ + var nextPartitionKey: InternalRow = _ // The first row of next partition. - var firstRowInNextPartition: Row = _ + var firstRowInNextPartition: InternalRow = _ // Indicates if this partition is the last one in the iter. var lastPartition: Boolean = false @@ -316,7 +315,7 @@ case class Window( !lastPartition || (rowPosition < partitionSize) } - override final def next(): Row = { + override final def next(): InternalRow = { if (hasNext) { if (rowPosition == partitionSize) { // All rows of this buffer have been consumed. @@ -353,7 +352,7 @@ case class Window( // Fetch the next partition. private def fetchNextPartition(): Unit = { // Create a new buffer for input rows. - inputRowBuffer = new CompactBuffer[Row]() + inputRowBuffer = new CompactBuffer[InternalRow]() // We already have the first row for this partition // (recorded in firstRowInNextPartition). Add it back. inputRowBuffer += firstRowInNextPartition diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6cb67b4bbbb6..647c4ab5cb65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.execution -import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.{CompletionIterator, MutablePair} +import org.apache.spark.{HashPartitioner, SparkEnv} /** * :: DeveloperApi :: @@ -37,9 +38,9 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - protected override def doExecute(): RDD[Row] = child.execute().mapPartitions { iter => - val resuableProjection = buildProjection() - iter.map(resuableProjection) + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map(reusableProjection) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -52,9 +53,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - @transient lazy val conditionEvaluator: (Row) => Boolean = newPredicate(condition, child.output) + @transient lazy val conditionEvaluator: (InternalRow) => Boolean = + newPredicate(condition, child.output) - protected override def doExecute(): RDD[Row] = child.execute().mapPartitions { iter => + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } @@ -65,7 +67,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * :: DeveloperApi :: * Sample the dataset. * @param lowerBound Lower-bound of the sampling probability (usually 0.0) - * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed @@ -83,7 +85,7 @@ case class Sample( override def output: Seq[Attribute] = child.output // TODO: How to pick seed? - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) } else { @@ -99,7 +101,8 @@ case class Sample( case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output: Seq[Attribute] = children.head.output - protected override def doExecute(): RDD[Row] = sparkContext.union(children.map(_.execute())) + protected override def doExecute(): RDD[InternalRow] = + sparkContext.union(children.map(_.execute())) } /** @@ -124,19 +127,19 @@ case class Limit(limit: Int, child: SparkPlan) override def executeCollect(): Array[Row] = child.executeTake(limit) - protected override def doExecute(): RDD[Row] = { - val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { + protected override def doExecute(): RDD[InternalRow] = { + val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => iter.take(limit).map(row => (false, row.copy())) } } else { child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Boolean, Row]() + val mutablePair = new MutablePair[Boolean, InternalRow]() iter.take(limit).map(row => mutablePair.update(false, row)) } } val part = new HashPartitioner(1) - val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) + val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) shuffled.mapPartitions(_.take(limit).map(_._2)) } @@ -144,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan) /** * :: DeveloperApi :: - * Take the first limit elements as defined by the sortOrder. This is logically equivalent to - * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but - * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. + * Take the first limit elements as defined by the sortOrder, and do projection if needed. + * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, + * or having a [[Project]] operator between them. + * This could have been named TopK, but Spark's top operator does the opposite in ordering + * so we name it TakeOrdered to avoid confusion. */ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { +case class TakeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -157,7 +166,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) - private def collectData(): Array[Row] = child.execute().map(_.copy()).takeOrdered(limit)(ord) + // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. + @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) + + private def collectData(): Array[InternalRow] = { + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) + projection.map(data.map(_)).getOrElse(data) + } override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) @@ -166,7 +181,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - protected override def doExecute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder } @@ -186,7 +201,7 @@ case class Sort( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[Row] = attachTree(this, "sort") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) iterator.map(_.copy()).toArray.sorted(ordering).iterator @@ -214,14 +229,14 @@ case class ExternalSort( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[Row] = attachTree(this, "sort") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[Row, Null, Row](ordering = Some(ordering)) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy, null))) val baseIterator = sorter.iterator.map(_._1) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[Row, Iterator[Row]](baseIterator, sorter.stop()) + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) } @@ -230,37 +245,6 @@ case class ExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } -/** - * :: DeveloperApi :: - * Computes the set of distinct input rows using a HashSet. - * @param partial when true the distinct operation is performed partially, per partition, without - * shuffling the data. - * @param child the input query plan. - */ -@DeveloperApi -case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - - override def requiredChildDistribution: Seq[Distribution] = - if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil - - protected override def doExecute(): RDD[Row] = { - child.execute().mapPartitions { iter => - val hashSet = new scala.collection.mutable.HashSet[Row]() - - var currentRow: Row = null - while (iter.hasNext) { - currentRow = iter.next() - if (!hashSet.contains(currentRow)) { - hashSet.add(currentRow.copy()) - } - } - - hashSet.iterator - } - } -} - /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. @@ -270,7 +254,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().map(_.copy()).coalesce(numPartitions, shuffle) } } @@ -285,7 +269,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } } @@ -299,7 +283,7 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = children.head.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } @@ -314,5 +298,5 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil - protected override def doExecute(): RDD[Row] = child.execute() + protected override def doExecute(): RDD[InternalRow] = child.execute() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 49b361e96b2d..5e9951f248ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution +import java.util.NoSuchElementException + import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. `RunnableCommand`s are @@ -64,9 +66,9 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray - protected override def doExecute(): RDD[Row] = { - val converted = sideEffectResult.map(r => - CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) + protected override def doExecute(): RDD[InternalRow] = { + val convert = CatalystTypeConverters.createToCatalystConverter(schema) + val converted = sideEffectResult.map(convert(_).asInstanceOf[InternalRow]) sqlContext.sparkContext.parallelize(converted, 1) } } @@ -75,48 +77,92 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan * :: DeveloperApi :: */ @DeveloperApi -case class SetCommand( - kv: Option[(String, Option[String])], - override val output: Seq[Attribute]) - extends RunnableCommand with Logging { +case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { + + private def keyValueOutput: Seq[Attribute] = { + val schema = StructType( + StructField("key", StringType, false) :: + StructField("value", StringType, false) :: Nil) + schema.toAttributes + } - override def run(sqlContext: SQLContext): Seq[Row] = kv match { + private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match { // Configures the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - if (value.toInt < 1) { - val msg = s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + - "determining the number of reducers is not supported." - throw new IllegalArgumentException(msg) - } else { - sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS, value) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + if (value.toInt < 1) { + val msg = + s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + + "determining the number of reducers is not supported." + throw new IllegalArgumentException(msg) + } else { + sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value) + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) + } } + (keyValueOutput, runFunc) // Configures a single property. case Some((key, Some(value))) => - sqlContext.setConf(key, value) - Seq(Row(s"$key=$value")) + val runFunc = (sqlContext: SQLContext) => { + sqlContext.setConf(key, value) + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) - // Queries all key-value pairs that are set in the SQLConf of the sqlContext. - // Notice that different from Hive, here "SET -v" is an alias of "SET". // (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.) - case Some(("-v", None)) | None => - sqlContext.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq + // Queries all key-value pairs that are set in the SQLConf of the sqlContext. + case None => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq + } + (keyValueOutput, runFunc) + + // Queries all properties along with their default values and docs that are defined in the + // SQLConf of the sqlContext. + case Some(("-v", None)) => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) => + Row(key, defaultValue, doc) + } + } + val schema = StructType( + StructField("key", StringType, false) :: + StructField("default", StringType, false) :: + StructField("meaning", StringType, false) :: Nil) + (schema.toAttributes, runFunc) // Queries the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.conf.numShufflePartitions}")) + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString)) + } + (keyValueOutput, runFunc) // Queries a single property. case Some((key, None)) => - Seq(Row(s"$key=${sqlContext.getConf(key, "")}")) + val runFunc = (sqlContext: SQLContext) => { + val value = + try { + sqlContext.getConf(key) + } catch { + case _: NoSuchElementException => "" + } + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) } + + override val output: Seq[Attribute] = _output + + override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext) + } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dffb265601bd..2964edac1aba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.unsafe.types.UTF8String import scala.collection.mutable.HashSet import org.apache.spark.{AccumulatorParam, Accumulator} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -46,7 +48,7 @@ package object debug { */ implicit class DebugSQLContext(sqlContext: SQLContext) { def debug(): Unit = { - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) } } @@ -125,11 +127,11 @@ package object debug { } } - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - new Iterator[Row] { + new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext - def next(): Row = { + def next(): InternalRow = { val currentRow = iter.next() tupleCount += 1 var i = 0 @@ -154,7 +156,7 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: Row, StructType(fields)) => + case (row: InternalRow, StructType(fields)) => row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (s: Seq[_], ArrayType(elemType, _)) => s.foreach(typeCheck(_, elemType)) @@ -170,6 +172,8 @@ package object debug { case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (_: Int, DateType) => + case (_: Long, TimestampType) => case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") @@ -193,7 +197,7 @@ package object debug { def children: List[SparkPlan] = child :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().map { row => try typeCheck(row, child.schema) catch { case e: Exception => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 9ac732b55b18..68914cf85cb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.{InternalRow, LeafExpression} import org.apache.spark.sql.types.{LongType, DataType} /** @@ -39,13 +39,11 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L - override type EvaluatedType = Long - override def nullable: Boolean = false override def dataType: DataType = LongType - override def eval(input: Row): Long = { + override def eval(input: InternalRow): Long = { val currentCount = count count += 1 (TaskContext.get().partitionId().toLong << 33) + currentCount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index c2c6cbd49159..12c2eed0d6b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, InternalRow} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -27,11 +27,9 @@ import org.apache.spark.sql.types.{IntegerType, DataType} */ private[sql] case object SparkPartitionID extends LeafExpression { - override type EvaluatedType = Int - override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def eval(input: Row): Int = TaskContext.get().partitionId() + override def eval(input: InternalRow): Int = TaskContext.get().partitionId() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 05dd5681edfa..2d2e1b92b86b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.rdd.RDD - import scala.concurrent._ import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Row, Expression} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Expression, InternalRow} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.ThreadUtils /** * :: DeveloperApi :: @@ -61,12 +60,12 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types - val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) - } + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -74,3 +73,9 @@ case class BroadcastHashJoin( } } } + +object BroadcastHashJoin { + + private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala new file mode 100644 index 000000000000..5da04c78744d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.ThreadUtils + +import scala.collection.JavaConversions._ +import scala.concurrent._ +import scala.concurrent.duration._ + +/** + * :: DeveloperApi :: + * Performs a outer hash join for two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + val timeout = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + override def requiredChildDistribution: Seq[Distribution] = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + private[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + private[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + + @transient + private val broadcastFuture = future { + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() + // buildHashTable uses code-generated rows as keys, which are not serializable + val hashed = + buildHashTable(input.iterator, new InterpretedProjection(buildKeys, buildPlan.output)) + sparkContext.broadcast(hashed) + }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + + override def doExecute(): RDD[InternalRow] = { + val broadcastRelation = Await.result(broadcastFuture, timeout) + + streamedPlan.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + + joinType match { + case LeftOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + streamedIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case x => + throw new IllegalArgumentException( + s"BroadcastHashOuterJoin should not take $x as the JoinType") + } + } + } +} + +object BroadcastHashOuterJoin { + + private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 640fc26ba3ba..412a3d4178e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -38,10 +38,10 @@ case class BroadcastLeftSemiJoinHash( override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[Row] = { - val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null + protected override def doExecute(): RDD[InternalRow] = { + val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator + val hashSet = new java.util.HashSet[InternalRow]() + var currentRow: InternalRow = null // Create a Hash set of buildKeys while (buildIter.hasNext) { @@ -50,7 +50,8 @@ case class BroadcastLeftSemiJoinHash( if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { - hashSet.add(rowKey) + // rowKey may be not serializable (from codegen) + hashSet.add(rowKey.copy()) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index caad3dfbe1c5..0b2cf8e12a6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -61,13 +61,14 @@ case class BroadcastNestedLoopJoin( @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()) + .collect().toIndexedSeq) /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] + val matchedRows = new CompactBuffer[InternalRow] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) @@ -118,8 +119,8 @@ case class BroadcastNestedLoopJoin( val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() + val broadcastRowsWithNulls: Seq[InternalRow] = { + val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value while (i < rel.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 191c00cb55da..261b4724159f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 851de1685509..3a4196a90d14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -49,11 +49,13 @@ trait HashJoin { @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = newMutableProjection(streamedKeys, streamedPlan.output) - protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = + protected def hashJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ + new Iterator[InternalRow] { + private[this] var currentStreamedRow: InternalRow = _ + private[this] var currentHashMatches: CompactBuffer[InternalRow] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. @@ -65,7 +67,7 @@ trait HashJoin { (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) - override final def next(): Row = { + override final def next(): InternalRow = { val ret = buildSide match { case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 45574392996c..886b5fa0c510 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -19,41 +19,32 @@ package org.apache.spark.sql.execution.joins import java.util.{HashMap => JavaHashMap} -import org.apache.spark.rdd.RDD - -import scala.collection.JavaConversions._ - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ @DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { +trait HashOuterJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val joinType: JoinType + val condition: Option[Expression] + val left: SparkPlan + val right: SparkPlan + +override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + case x => + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } - override def requiredChildDistribution: Seq[ClusteredDistribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -63,29 +54,32 @@ case class HashOuterJoin( case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") } } - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - @transient private[this] lazy val leftNullRow = new GenericRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericRow(right.output.length) + @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @transient private[this] lazy val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + condition.map( + newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true) // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. - private[this] def leftOuterIterator( - key: Row, joinedRow: JoinedRow, rightIter: Iterable[Row]): Iterator[Row] = { - val ret: Iterable[Row] = { + protected[this] def leftOuterIterator( + key: InternalRow, + joinedRow: JoinedRow, + rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = rightIter.collect { case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withRight(rightNullRow).copy :: Nil } else { temp @@ -97,15 +91,17 @@ case class HashOuterJoin( ret.iterator } - private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], joinedRow: JoinedRow): Iterator[Row] = { - - val ret: Iterable[Row] = { + protected[this] def rightOuterIterator( + key: InternalRow, + leftIter: Iterable[InternalRow], + joinedRow: JoinedRow): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => joinedRow.copy + case l if boundCondition(joinedRow.withLeft(l)) => + joinedRow.copy() } - if (temp.size == 0) { + if (temp.isEmpty) { joinedRow.withLeft(leftNullRow).copy :: Nil } else { temp @@ -117,15 +113,14 @@ case class HashOuterJoin( ret.iterator } - private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row], - joinedRow: JoinedRow): Iterator[Row] = { - + protected[this] def fullOuterIterator( + key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], + joinedRow: JoinedRow): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => + leftIter.iterator.flatMap[InternalRow] { l => joinedRow.withLeft(l) var matched = false rightIter.zipWithIndex.collect { @@ -156,24 +151,25 @@ case class HashOuterJoin( joinedRow(leftNullRow, r).copy() } } else { - leftIter.iterator.map[Row] { l => + leftIter.iterator.map[InternalRow] { l => joinedRow(l, rightNullRow).copy() - } ++ rightIter.iterator.map[Row] { r => + } ++ rightIter.iterator.map[InternalRow] { r => joinedRow(leftNullRow, r).copy() } } } - private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + protected[this] def buildHashTable( + iter: Iterator[InternalRow], + keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { + val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() + existingMatchList = new CompactBuffer[InternalRow]() hashTable.put(rowKey, existingMatchList) } @@ -182,42 +178,4 @@ case class HashOuterJoin( hashTable } - - protected override def doExecute(): RDD[Row] = { - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - - joinType match { - case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) - leftIter.flatMap( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) - }) - - case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) - rightIter.flatMap ( currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) - }) - - case FullOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) - } - - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ab84c123e0c0..e18c81797513 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.util.{HashMap => JavaHashMap} -import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.sql.catalyst.expressions.{Projection, InternalRow} import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.util.collection.CompactBuffer @@ -30,7 +30,7 @@ import org.apache.spark.util.collection.CompactBuffer * object. */ private[joins] sealed trait HashedRelation { - def get(key: Row): CompactBuffer[Row] + def get(key: InternalRow): CompactBuffer[InternalRow] // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation @@ -54,12 +54,12 @@ private[joins] sealed trait HashedRelation { * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. */ private[joins] final class GeneralHashedRelation( - private var hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { def this() = this(null) // Needed for serialization - override def get(key: Row): CompactBuffer[Row] = hashTable.get(key) + override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -75,17 +75,18 @@ private[joins] final class GeneralHashedRelation( * A specialized [[HashedRelation]] that maps key into a single value. This implementation * assumes the key is unique. */ -private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[Row, Row]) +private[joins] +final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { def this() = this(null) // Needed for serialization - override def get(key: Row): CompactBuffer[Row] = { + override def get(key: InternalRow): CompactBuffer[InternalRow] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } - def getValue(key: Row): Row = hashTable.get(key) + def getValue(key: InternalRow): InternalRow = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -103,13 +104,13 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHa private[joins] object HashedRelation { def apply( - input: Iterator[Row], + input: Iterator[InternalRow], keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { // TODO: Use Spark's HashMap implementation. - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate) - var currentRow: Row = null + val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate) + var currentRow: InternalRow = null // Whether the join key is unique. If the key is unique, we can convert the underlying // hash map into one specialized for this. @@ -122,7 +123,7 @@ private[joins] object HashedRelation { if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() + val newMatchList = new CompactBuffer[InternalRow]() hashTable.put(rowKey, newMatchList) newMatchList } else { @@ -134,7 +135,7 @@ private[joins] object HashedRelation { } if (keyIsUnique) { - val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size) + val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size) val iter = hashTable.entrySet().iterator() while (iter.hasNext) { val entry = iter.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 036423e6faea..2a6d4d1ab08b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -47,7 +47,7 @@ case class LeftSemiJoinBNL( @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val broadcastedRelation = sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 8ad27eae80ff..20d74270afb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow} import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -42,10 +42,10 @@ case class LeftSemiJoinHash( override def output: Seq[Attribute] = left.output - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null + val hashSet = new java.util.HashSet[InternalRow]() + var currentRow: InternalRow = null // Create a Hash set of buildKeys while (buildIter.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 219525d9d85f..5439e10a60b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -43,7 +43,7 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[ClusteredDistribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala new file mode 100644 index 000000000000..cfc9c14aaa36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +import scala.collection.JavaConversions._ + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class ShuffledHashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashOuterJoin { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val joinedRow = new JoinedRow() + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + joinType match { + case LeftOuter => + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val keyGenerator = newProjection(leftKeys, left.output) + leftIter.flatMap( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + }) + + case RightOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val keyGenerator = newProjection(rightKeys, right.output) + rightIter.flatMap ( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) + + case FullOuter => + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST), + joinedRow) + } + + case x => + throw new IllegalArgumentException( + s"ShuffledHashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 1a39fb4b9660..2abe65a71813 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -21,9 +21,7 @@ import java.util.NoSuchElementException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -60,29 +58,29 @@ case class SortMergeJoin( private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = keys.map(SortOrder(_, Ascending)) - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val leftResults = left.execute().map(_.copy()) val rightResults = right.execute().map(_.copy()) leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[Row] { + new Iterator[InternalRow] { // Mutable per row objects. private[this] val joinRow = new JoinedRow5 - private[this] var leftElement: Row = _ - private[this] var rightElement: Row = _ - private[this] var leftKey: Row = _ - private[this] var rightKey: Row = _ - private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var leftElement: InternalRow = _ + private[this] var rightElement: InternalRow = _ + private[this] var leftKey: InternalRow = _ + private[this] var rightKey: InternalRow = _ + private[this] var rightMatches: CompactBuffer[InternalRow] = _ private[this] var rightPosition: Int = -1 private[this] var stop: Boolean = false - private[this] var matchKey: Row = _ + private[this] var matchKey: InternalRow = _ // initialize iterator initialize() override final def hasNext: Boolean = nextMatchingPair() - override final def next(): Row = { + override final def next(): InternalRow = { if (hasNext) { // we are using the buffered right rows and run down left iterator val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) @@ -145,7 +143,7 @@ case class SortMergeJoin( fetchLeft() } } - rightMatches = new CompactBuffer[Row]() + rightMatches = new CompactBuffer[InternalRow]() if (stop) { stop = false // iterate the right side to buffer all rows that matches diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 65dd7ba020fa..9e1cff06c7ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -24,18 +24,19 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} +import org.apache.spark.{Accumulator, Logging => SparkLogging} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.unsafe.types.UTF8String /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -46,6 +47,7 @@ private[spark] case class PythonUDF( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, @@ -53,10 +55,10 @@ private[spark] case class PythonUDF( override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - def nullable: Boolean = true + override def nullable: Boolean = true - override def eval(input: Row): PythonUDF.this.EvaluatedType = { - sys.error("PythonUDFs can not be directly evaluated.") + override def eval(input: InternalRow): Any = { + throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") } } @@ -67,46 +69,52 @@ private[spark] case class PythonUDF( * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. - case p: EvaluatePython => p + case plan: EvaluatePython => plan - case l: LogicalPlan => + case plan: LogicalPlan if plan.resolved => // Extract any PythonUDFs from the current operator. - val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf}) + val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) if (udfs.isEmpty) { // If there aren't any, we are done. - l + plan } else { // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) // If there is more than one, we will add another evaluation operator in a subsequent pass. - val udf = udfs.head - - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = l.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartisian product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } + udfs.find(_.resolved) match { + case Some(udf) => + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + plan.output, + plan.transformExpressions { + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute + }.withNewChildren(newChildren)) + + case None => + // If there is no Python UDF that is resolved, skip this round. + plan } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. - logical.Project( - l.output, - l.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) } } } @@ -140,7 +148,8 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) - case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) + case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal @@ -174,15 +183,17 @@ object EvaluatePython { }.toMap case (c, StructType(fields)) if c.getClass.isArray => - new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { + new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map { case (e, f) => fromJava(e, f.dataType) - }): Row + }) case (c: java.util.Calendar, DateType) => - DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime())) + DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) case (c: java.util.Calendar, TimestampType) => - new java.sql.Timestamp(c.getTime().getTime()) + c.getTimeInMillis * 10000L + case (t: java.sql.Timestamp, TimestampType) => + DateTimeUtils.fromJavaTimestamp(t) case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) @@ -194,8 +205,10 @@ object EvaluatePython { case (c: Long, IntegerType) => c.toInt case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat - case (c: String, StringType) => UTF8String(c) - case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) + case (c: String, StringType) => UTF8String.fromString(c) + case (c, StringType) => + // If we get here, c is not a string. Call toString on it. + UTF8String.fromString(c.toString) case (c, _) => c } @@ -229,7 +242,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) val parent = childResults.mapPartitions { iter => @@ -251,6 +264,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: udf.pythonIncludes, false, udf.pythonExec, + udf.pythonVer, udf.broadcastVars, udf.accumulator ).mapPartitions { iter => @@ -263,7 +277,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val row = new GenericMutableRow(1) iter.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: Row + row: InternalRow } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 5ae7e107544f..4e2e2c210d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{ArrayType, StructField, StructType} +import org.apache.spark.sql.{Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -62,7 +63,7 @@ private[sql] object FrequentItems extends Logging { } /** - * Finding frequent items for columns, possibly with false positives. Using the + * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. @@ -75,7 +76,7 @@ private[sql] object FrequentItems extends Logging { * @return A Local DataFrame with the Array of frequent items for each column. */ private[sql] def singlePassFreqItems( - df: DataFrame, + df: DataFrame, cols: Seq[String], support: Double): DataFrame = { require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") @@ -88,8 +89,8 @@ private[sql] object FrequentItems extends Logging { val index = originalSchema.fieldIndex(name) (name, originalSchema.fields(index).dataType) } - - val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)( + + val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { @@ -110,7 +111,7 @@ private[sql] object FrequentItems extends Logging { } ) val justItems = freqItems.map(m => m.baseMap.keys.toSeq) - val resultRow = Row(justItems:_*) + val resultRow = InternalRow(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d22f5fd2d439..00231d65a7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,14 +18,15 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] object StatFunctions extends Logging { - + /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols) @@ -81,7 +82,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).rdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, @@ -109,23 +110,40 @@ private[sql] object StatFunctions extends Logging { logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " + "the pairs. Please try reducing the amount of distinct items in your columns.") } + def cleanElement(element: Any): String = { + if (element == null) "null" else element.toString + } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = + counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => val countsRow = new GenericMutableRow(columnSize + 1) - rows.foreach { row => - countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) + rows.foreach { (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(2) is the frequency + val columnIndex = distinctCol2.get(cleanElement(row.get(1))).get + countsRow.setLong(columnIndex + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts - countsRow.setString(0, col1Item.toString) + countsRow.update(0, UTF8String.fromString(cleanElement(col1Item))) countsRow }.toSeq - val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq + // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept + // special keywords and `.`, wrap the column names in ``. + def cleanColumnName(name: String): String = { + name.replace("`", "") + } + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r => + StructField(cleanColumnName(r._1.toString), LongType) + } val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)) + new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala new file mode 100644 index 000000000000..e9b60841fc28 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +object Window { + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + spec.partitionBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + spec.partitionBy(cols : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + spec.orderBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + spec.orderBy(cols : _*) + } + + private def spec: WindowSpec = { + new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) + } + +} + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala new file mode 100644 index 000000000000..c3d224629702 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.catalyst.expressions._ + + +/** + * :: Experimental :: + * A window specification that defines the partitioning, ordering, and frame boundaries. + * + * Use the static methods in [[Window]] to create a [[WindowSpec]]. + * + * @since 1.4.0 + */ +@Experimental +class WindowSpec private[sql]( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + frame: catalyst.expressions.WindowFrame) { + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + partitionBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + new WindowSpec(cols.map(_.expr), orderSpec, frame) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + orderBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + val sortOrder: Seq[SortOrder] = cols.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + new WindowSpec(partitionSpec, sortOrder, frame) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative positions from the current row. For example, "0" means + * "current row", while "-1" means the row before the current row, and "5" means the fifth row + * after the current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rowsBetween(start: Long, end: Long): WindowSpec = { + between(RowFrame, start, end) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * while "-1" means one off before the current row, and "5" means the five off after the + * current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rangeBetween(start: Long, end: Long): WindowSpec = { + between(RangeFrame, start, end) + } + + private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if x < 0 => ValuePreceding(-start.toInt) + case x if x > 0 => ValueFollowing(start.toInt) + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if x < 0 => ValuePreceding(-end.toInt) + case x if x > 0 => ValueFollowing(end.toInt) + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + } + + /** + * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. + */ + private[sql] def withAggregate(aggregate: Column): Column = { + val windowExpr = aggregate.expr match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction("first_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction("last_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case wf: WindowFunction => WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in window operation.") + } + new Column(windowExpr) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6640631cf071..4da9ffc495e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -34,9 +35,13 @@ import org.apache.spark.util.Utils * * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions + * @groupname datetime_funcs Date time functions * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions + * @groupname misc_funcs Misc functions + * @groupname window_funcs Window functions + * @groupname string_funcs String functions * @groupname Ungrouped Support functions for DataFrames. * @since 1.3.0 */ @@ -186,7 +191,7 @@ object functions { */ @scala.annotation.varargs def countDistinct(columnName: String, columnNames: String*): Column = - countDistinct(Column(columnName), columnNames.map(Column.apply) :_*) + countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -320,6 +325,218 @@ object functions { */ def max(columnName: String): Column = max(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Window functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int): Column = { + lag(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int): Column = { + lag(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int, defaultValue: Any): Column = { + lag(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int): Column = { + lead(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int): Column = { + lead(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int, defaultValue: Any): Column = { + lead(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def ntile(n: Int): Column = { + UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) + } + + /** + * Window function: returns a sequential number starting at 1 within a window partition. + * + * This is equivalent to the ROW_NUMBER function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rowNumber(): Column = { + UnresolvedWindowFunction("row_number", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition, without any gaps. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the DENSE_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def denseRank(): Column = { + UnresolvedWindowFunction("dense_rank", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rank(): Column = { + UnresolvedWindowFunction("rank", Nil) + } + + /** + * Window function: returns the cumulative distribution of values within a window partition, + * i.e. the fraction of rows that are below the current row. + * + * {{{ + * N = total number of rows in the partition + * cumeDist(x) = number of values before (and including) x / N + * }}} + * + * + * This is equivalent to the CUME_DIST function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def cumeDist(): Column = { + UnresolvedWindowFunction("cume_dist", Nil) + } + + /** + * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. + * + * This is computed by: + * {{{ + * (rank of row in its partition - 1) / (number of rows in the partition - 1) + * }}} + * + * This is equivalent to the PERCENT_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def percentRank(): Column = { + UnresolvedWindowFunction("percent_rank", Nil) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -351,6 +568,22 @@ object functions { array((colName +: colNames).map(col) : _*) } + /** + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. + * {{{ + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") + * }}} + * + * @group normal_funcs + * @since 1.5.0 + */ + def broadcast(df: DataFrame): DataFrame = { + DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan)) + } + /** * Returns the first column that is not null. * {{{ @@ -494,23 +727,32 @@ object functions { /** * Computes the square root of the specified float value. * - * @group normal_funcs + * @group math_funcs * @since 1.3.0 */ def sqrt(e: Column): Column = Sqrt(e.expr) /** - * Creates a new struct column. The input column must be a column in a [[DataFrame]], or - * a derived column expression that is named (i.e. aliased). + * Computes the square root of the specified float value. + * + * @group math_funcs + * @since 1.5.0 + */ + def sqrt(colName: String): Column = sqrt(Column(colName)) + + /** + * Creates a new struct column. + * If the input column is a column in a [[DataFrame]], or a derived column expression + * that is named (i.e. aliased), its name would be remained as the StructField's name, + * otherwise, the newly generated StructField's name would be auto generated as col${index + 1}, + * i.e. col1, col2, col3, ... * * @group normal_funcs * @since 1.4.0 */ @scala.annotation.varargs def struct(cols: Column*): Column = { - require(cols.forall(_.expr.isInstanceOf[NamedExpression]), - s"struct input columns must all be named or aliased ($cols)") - CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + CreateStruct(cols.map(_.expr)) } /** @@ -668,6 +910,24 @@ object functions { */ def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(e: Column): Column = Bin(e.expr) + + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(columnName: String): Column = bin(Column(columnName)) + /** * Computes the cube-root of the given value. * @@ -732,6 +992,22 @@ object functions { */ def cosh(columnName: String): Column = cosh(Column(columnName)) + /** + * Returns the current date. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_date(): Column = CurrentDate() + + /** + * Returns the current timestamp. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def current_timestamp(): Column = CurrentTimestamp() + /** * Computes the exponential of the given value. * @@ -764,6 +1040,22 @@ object functions { */ def expm1(columnName: String): Column = expm1(Column(columnName)) + /** + * Computes the factorial of the given value. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(e: Column): Column = Factorial(e.expr) + + /** + * Computes the factorial of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ + def factorial(columnName: String): Column = factorial(Column(columnName)) + /** * Computes the floor of the given value. * @@ -780,6 +1072,40 @@ object functions { */ def floor(columnName: String): Column = floor(Column(columnName)) + /** + * Computes hex value of the given column + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = Hex(column.expr) + + /** + * Computes hex value of the given input + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(colName: String): Column = hex(Column(colName)) + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(column: Column): Column = Unhex(column.expr) + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * + * @group math_funcs + * @since 1.5.0 + */ + def unhex(colName: String): Column = unhex(Column(colName)) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * @@ -862,7 +1188,23 @@ object functions { def log(columnName: String): Column = log(Column(columnName)) /** - * Computes the logarithm of the given value in Base 10. + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + + /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, columnName: String): Column = log(base, Column(columnName)) + + /** + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -870,7 +1212,7 @@ object functions { def log10(e: Column): Column = Log10(e.expr) /** - * Computes the logarithm of the given value in Base 10. + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -893,6 +1235,22 @@ object functions { */ def log1p(columnName: String): Column = log1p(Column(columnName)) + /** + * Computes the logarithm of the given column in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(expr: Column): Column = Log2(expr.expr) + + /** + * Computes the logarithm of the given value in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(columnName: String): Column = log2(Column(columnName)) + /** * Returns the value of the first argument raised to the power of the second argument. * @@ -975,6 +1333,64 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits left. If the given value is a long value, this function + * will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftLeft(columnName: String, numBits: Int): Column = + shiftLeft(Column(columnName), numBits) + + /** + * Shift the the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(columnName: String, numBits: Int): Column = + shiftRightUnsigned(Column(columnName), numBits) + + /** + * Unsigned shift the the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRightUnsigned(e: Column, numBits: Int): Column = + ShiftRightUnsigned(e.expr, lit(numBits).expr) + + /** + * Shift the the given value numBits right. If the given value is a long value, it will return + * a long value else it will return an integer value. + * + * @group math_funcs + * @since 1.5.0 + */ + def shiftRight(columnName: String, numBits: Int): Column = + shiftRight(Column(columnName), numBits) + /** * Computes the signum of the given value. * @@ -1086,7 +1502,206 @@ object functions { * @since 1.4.0 */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) - + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Misc functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Calculates the MD5 digest and returns the value as a 32 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def md5(e: Column): Column = Md5(e.expr) + + /** + * Calculates the MD5 digest and returns the value as a 32 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def md5(columnName: String): Column = md5(Column(columnName)) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(columnName: String): Column = sha1(Column(columnName)) + + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(e: Column, numBits: Int): Column = { + require(Seq(0, 224, 256, 384, 512).contains(numBits), + s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") + Sha2(e.expr, lit(numBits).expr) + } + + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) + + /** + * Calculates the cyclic redundancy check value and returns the value as a bigint. + * + * @group misc_funcs + * @since 1.5.0 + */ + def crc32(e: Column): Column = Crc32(e.expr) + + /** + * Calculates the cyclic redundancy check value and returns the value as a bigint. + * + * @group misc_funcs + * @since 1.5.0 + */ + def crc32(columnName: String): Column = crc32(Column(columnName)) + + ////////////////////////////////////////////////////////////////////////////////////////////// + // String functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Computes the length of a given string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def strlen(e: Column): Column = StringLength(e.expr) + + /** + * Computes the length of a given string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def strlen(columnName: String): Column = strlen(Column(columnName)) + + /** + * Computes the Levenshtein distance of the two given strings. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + + /** + * Computes the Levenshtein distance of the two given strings. + * @group string_funcs + * @since 1.5.0 + */ + def levenshtein(leftColumnName: String, rightColumnName: String): Column = + levenshtein(Column(leftColumnName), Column(rightColumnName)) + + /** + * Computes the numeric value of the first character of the specified string value. + * + * @group string_funcs + * @since 1.5.0 + */ + def ascii(e: Column): Column = Ascii(e.expr) + + /** + * Computes the numeric value of the first character of the specified string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def ascii(columnName: String): Column = ascii(Column(columnName)) + + /** + * Computes the specified value from binary to a base64 string. + * + * @group string_funcs + * @since 1.5.0 + */ + def base64(e: Column): Column = Base64(e.expr) + + /** + * Computes the specified column from binary to a base64 string. + * + * @group string_funcs + * @since 1.5.0 + */ + def base64(columnName: String): Column = base64(Column(columnName)) + + /** + * Computes the specified value from a base64 string to binary. + * + * @group string_funcs + * @since 1.5.0 + */ + def unbase64(e: Column): Column = UnBase64(e.expr) + + /** + * Computes the specified column from a base64 string to binary. + * + * @group string_funcs + * @since 1.5.0 + */ + def unbase64(columnName: String): Column = unbase64(Column(columnName)) + + /** + * Computes the first argument into a binary from a string using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + + /** + * Computes the first argument into a binary from a string using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * NOTE: charset represents the string value of the character set, not the column name. + * + * @group string_funcs + * @since 1.5.0 + */ + def encode(columnName: String, charset: String): Column = + encode(Column(columnName), charset) + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group string_funcs + * @since 1.5.0 + */ + def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * NOTE: charset represents the string value of the character set, not the column name. + * + * @group string_funcs + * @since 1.5.0 + */ + def decode(columnName: String, charset: String): Column = + decode(Column(columnName), charset) + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1097,6 +1712,7 @@ object functions { (0 to 10).map { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" /** * Defines a user-defined function of ${x} arguments as user-defined function (UDF). @@ -1106,14 +1722,15 @@ object functions { * @since 1.3.0 */ def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try($inputTypes).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) }""") } (0 to 10).map { x => val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires @@ -1121,9 +1738,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) + ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } } @@ -1136,7 +1755,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1147,7 +1767,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1158,7 +1779,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1169,7 +1791,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1180,7 +1803,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1191,7 +1815,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1202,7 +1827,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1213,7 +1839,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1224,7 +1851,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1235,7 +1863,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } /** @@ -1246,7 +1875,8 @@ object functions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil) + UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1257,9 +1887,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) + ScalaUDF(f, returnType, Seq()) } /** @@ -1268,9 +1900,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr)) } /** @@ -1279,9 +1913,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** @@ -1290,9 +1926,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** @@ -1301,9 +1939,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** @@ -1312,9 +1952,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** @@ -1323,9 +1965,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** @@ -1334,9 +1978,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** @@ -1345,9 +1991,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** @@ -1356,9 +2004,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** @@ -1367,9 +2017,11 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on @@ -1382,15 +2034,44 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) + * }}} + * + * @group udf_funcs + * @since 1.5.0 + */ + def callUDF(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) * }}} * * @group udf_funcs * @since 1.4.0 + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF */ + @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + // Note: we avoid using closures here because on file systems that are case-insensitive, the + // compiled class file for the closure here will conflict with the one in callUDF (upper case). + val exprs = new Array[Expression](cols.size) + var i = 0 + while (i < cols.size) { + exprs(i) = cols(i).expr + i += 1 + } + UnresolvedFunction(udfName, exprs) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala deleted file mode 100644 index 0feabc4282f4..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import org.apache.spark.sql.types._ - -import java.sql.Types - - -/** - * Encapsulates workarounds for the extensions, quirks, and bugs in various - * databases. Lots of databases define types that aren't explicitly supported - * by the JDBC spec. Some JDBC drivers also report inaccurate - * information---for instance, BIT(n>1) being reported as a BIT type is quite - * common, even though BIT in JDBC is meant for single-bit values. Also, there - * does not appear to be a standard name for an unbounded string or binary - * type; we use BLOB and CLOB by default but override with database-specific - * alternatives when these are absent or do not behave correctly. - * - * Currently, the only thing DriverQuirks does is handle type mapping. - * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` - * is used when writing to a JDBC table. If `getCatalystType` returns `null`, - * the default type handling is used for the given JDBC type. Similarly, - * if `getJDBCType` returns `(null, None)`, the default type handling is used - * for the given Catalyst type. - */ -private[sql] abstract class DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType - def getJDBCType(dt: DataType): (String, Option[Int]) -} - -private[sql] object DriverQuirks { - /** - * Fetch the DriverQuirks class corresponding to a given database url. - */ - def get(url: String): DriverQuirks = { - if (url.startsWith("jdbc:mysql")) { - new MySQLQuirks() - } else if (url.startsWith("jdbc:postgresql")) { - new PostgresQuirks() - } else { - new NoQuirks() - } - } -} - -private[sql] class NoQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = - null - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} - -private[sql] class PostgresQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - BinaryType - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - StringType - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - StringType - } else null - } - - def getJDBCType(dt: DataType): (String, Option[Int]) = dt match { - case StringType => ("TEXT", Some(java.sql.Types.CHAR)) - case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY)) - case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN)) - case _ => (null, None) - } -} - -private[sql] class MySQLQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - LongType - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - BooleanType - } else null - } - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 40483d3ec770..30c5f4ca3e1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -22,27 +22,42 @@ import java.util.Properties import org.apache.commons.lang3.StringUtils -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions.{InternalRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} + +/** + * Data corresponding to one partition of a JDBCRDD. + */ +private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { + override def index: Int = idx +} + private[sql] object JDBCRDD extends Logging { + /** * Maps a JDBC type to a Catalyst type. This function is called only when - * the DriverQuirks class corresponding to your database driver returns null. + * the JdbcDialect class corresponding to your database driver returns null. * * @param sqlType - A field of java.sql.Types * @return The Catalyst type corresponding to sqlType. */ - private def getCatalystType(sqlType: Int, precision: Int, scale: Int): DataType = { + private def getCatalystType( + sqlType: Int, + precision: Int, + scale: Int, + signed: Boolean): DataType = { val answer = sqlType match { + // scalastyle:off case java.sql.Types.ARRAY => null - case java.sql.Types.BIGINT => LongType + case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) } case java.sql.Types.BINARY => BinaryType - case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers. + case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks case java.sql.Types.BLOB => BinaryType case java.sql.Types.BOOLEAN => BooleanType case java.sql.Types.CHAR => StringType @@ -55,7 +70,7 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType case java.sql.Types.FLOAT => FloatType - case java.sql.Types.INTEGER => IntegerType + case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType } case java.sql.Types.JAVA_OBJECT => null case java.sql.Types.LONGNVARCHAR => StringType case java.sql.Types.LONGVARBINARY => BinaryType @@ -79,7 +94,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType - case _ => null + case _ => null + // scalastyle:on } if (answer == null) throw new SQLException("Unsupported type " + sqlType) @@ -99,7 +115,7 @@ private[sql] object JDBCRDD extends Logging { * @throws SQLException if the table contains an unsupported type. */ def resolveTable(url: String, table: String, properties: Properties): StructType = { - val quirks = DriverQuirks.get(url) + val dialect = JdbcDialects.get(url) val conn: Connection = DriverManager.getConnection(url, properties) try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() @@ -114,10 +130,12 @@ private[sql] object JDBCRDD extends Logging { val typeName = rsmd.getColumnTypeName(i + 1) val fieldSize = rsmd.getPrecision(i + 1) val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) - var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata) - if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale) + val columnType = + dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } @@ -168,6 +186,7 @@ private[sql] object JDBCRDD extends Logging { DriverManager.getConnection(url, properties) } } + /** * Build and return JDBCRDD from the given information. * @@ -192,19 +211,18 @@ private[sql] object JDBCRDD extends Logging { fqTable: String, requiredColumns: Array[String], filters: Array[Filter], - parts: Array[Partition]): RDD[Row] = { - - val prunedSchema = pruneSchema(schema, requiredColumns) - - return new - JDBCRDD( - sc, - getConnector(driver, url, properties), - prunedSchema, - fqTable, - requiredColumns, - filters, - parts) + parts: Array[Partition]): RDD[InternalRow] = { + val dialect = JdbcDialects.get(url) + val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) + new JDBCRDD( + sc, + getConnector(driver, url, properties), + pruneSchema(schema, requiredColumns), + fqTable, + quotedColumns, + filters, + parts, + properties) } } @@ -220,8 +238,9 @@ private[sql] class JDBCRDD( fqTable: String, columns: Array[String], filters: Array[Filter], - partitions: Array[Partition]) - extends RDD[Row](sc, Nil) { + partitions: Array[Partition], + properties: Properties) + extends RDD[InternalRow](sc, Nil) { /** * Retrieve the list of partitions corresponding to this RDD. @@ -246,7 +265,7 @@ private[sql] class JDBCRDD( } private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") + if (value == null) null else StringUtils.replace(value, "'", "''") /** * Turns a single Filter into a String representing a SQL expression. @@ -288,13 +307,13 @@ private[sql] class JDBCRDD( // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that // we don't have to potentially poke around in the Metadata once for every - // row. + // row. // Is there a better way to do this? I'd rather be using a type that // contains only the tags I define. abstract class JDBCConversion case object BooleanConversion extends JDBCConversion case object DateConversion extends JDBCConversion - case object DecimalConversion extends JDBCConversion + case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion case object DoubleConversion extends JDBCConversion case object FloatConversion extends JDBCConversion case object IntegerConversion extends JDBCConversion @@ -309,19 +328,19 @@ private[sql] class JDBCRDD( */ def getConversions(schema: StructType): Array[JDBCConversion] = { schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Unlimited => DecimalConversion - case DecimalType.Fixed(d) => DecimalConversion - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => + case BooleanType => BooleanConversion + case DateType => DateConversion + case DecimalType.Unlimited => DecimalConversion(None) + case DecimalType.Fixed(d) => DecimalConversion(Some(d)) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case _ => throw new IllegalArgumentException(s"Unsupported field $sf") }).toArray } @@ -329,12 +348,12 @@ private[sql] class JDBCRDD( /** * Runs the SQL query against the JDBC driver. */ - override def compute(thePart: Partition, context: TaskContext): Iterator[Row] = new Iterator[Row] - { + override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = + new Iterator[InternalRow] { var closed = false var finished = false var gotNext = false - var nextValue: Row = null + var nextValue: InternalRow = null context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] @@ -349,41 +368,64 @@ private[sql] class JDBCRDD( val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + val fetchSize = properties.getProperty("fetchSize", "0").toInt + stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() val conversions = getConversions(schema) val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType)) - def getNext(): Row = { + def getNext(): InternalRow = { if (rs.next()) { var i = 0 while (i < conversions.length) { val pos = i + 1 conversions(i) match { - case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - case DateConversion => - // DateUtils.fromJavaDate does not handle null value, so we need to check it. + case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + case DateConversion => + // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. val dateVal = rs.getDate(pos) if (dateVal != null) { - mutableRow.update(i, DateUtils.fromJavaDate(dateVal)) + mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal)) } else { mutableRow.update(i, null) } - case DecimalConversion => + // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal + // object returned by ResultSet.getBigDecimal is not correctly matched to the table + // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. + // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through + // a BigDecimal object with scale as 0. But the dataframe schema has correct type as + // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then + // retrieve it, you will get wrong result 199.99. + // So it is needed to set precision and scale for Decimal based on JDBC metadata. + case DecimalConversion(Some((p, s))) => + val decimalVal = rs.getBigDecimal(pos) + if (decimalVal == null) { + mutableRow.update(i, null) + } else { + mutableRow.update(i, Decimal(decimalVal, p, s)) + } + case DecimalConversion(None) => val decimalVal = rs.getBigDecimal(pos) if (decimalVal == null) { mutableRow.update(i, null) } else { mutableRow.update(i, Decimal(decimalVal)) } - case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) - case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) - case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) - case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) + case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) + case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) + case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) - case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) + case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos))) + case TimestampConversion => + val t = rs.getTimestamp(pos) + if (t != null) { + mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t)) + } else { + mutableRow.update(i, null) + } + case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) case BinaryLongConversion => { val bytes = rs.getBytes(pos) var ans = 0L @@ -401,7 +443,7 @@ private[sql] class JDBCRDD( mutableRow } else { finished = true - null.asInstanceOf[Row] + null.asInstanceOf[InternalRow] } } @@ -444,7 +486,7 @@ private[sql] class JDBCRDD( !finished } - override def next(): Row = { + override def next(): InternalRow = { if (!hasNext) { throw new NoSuchElementException("End of stream") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 93e82549f213..4d3aac464c53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -17,26 +17,15 @@ package org.apache.spark.sql.jdbc -import java.sql.DriverManager import java.util.Properties import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - -/** - * Data corresponding to one partition of a JDBCRDD. - */ -private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { - override def index: Int = idx -} +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Instructions on how to partition the table among workers. @@ -64,7 +53,7 @@ private[sql] object JDBCRelation { if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions + val stride: Long = (partitioning.upperBound / numPartitions - partitioning.lowerBound / numPartitions) var i: Int = 0 var currentValue: Long = partitioning.lowerBound @@ -148,10 +137,12 @@ private[sql] case class JDBCRelation( table, requiredColumns, filters, - parts) + parts).map(_.asInstanceOf[Row]) } - + override def insert(data: DataFrame, overwrite: Boolean): Unit = { - data.insertIntoJDBC(url, table, overwrite, properties) - } + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(url, table, properties) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala new file mode 100644 index 000000000000..8849fc2f1f0e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * A database type definition coupled with the jdbc type needed to send null + * values to the database. + * @param databaseTypeDefinition The database type definition + * @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to + * send a null value to the database. + */ +@DeveloperApi +case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) + +/** + * :: DeveloperApi :: + * Encapsulates everything (extensions, workarounds, quirks) to handle the + * SQL dialect of a certain database or jdbc driver. + * Lots of databases define types that aren't explicitly supported + * by the JDBC spec. Some JDBC drivers also report inaccurate + * information---for instance, BIT(n>1) being reported as a BIT type is quite + * common, even though BIT in JDBC is meant for single-bit values. Also, there + * does not appear to be a standard name for an unbounded string or binary + * type; we use BLOB and CLOB by default but override with database-specific + * alternatives when these are absent or do not behave correctly. + * + * Currently, the only thing done by the dialect is type mapping. + * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` + * is used when writing to a JDBC table. If `getCatalystType` returns `null`, + * the default type handling is used for the given JDBC type. Similarly, + * if `getJDBCType` returns `(null, None)`, the default type handling is used + * for the given Catalyst type. + */ +@DeveloperApi +abstract class JdbcDialect { + /** + * Check if this dialect instance can handle a certain jdbc url. + * @param url the jdbc url. + * @return True if the dialect can be applied on the given jdbc url. + * @throws NullPointerException if the url is null. + */ + def canHandle(url : String): Boolean + + /** + * Get the custom datatype mapping for the given jdbc meta information. + * @param sqlType The sql type (see java.sql.Types) + * @param typeName The sql type name (e.g. "BIGINT UNSIGNED") + * @param size The size of the type. + * @param md Result metadata associated with this type. + * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]]) + * or null if the default type mapping should be used. + */ + def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None + + /** + * Retrieve the jdbc / sql type for a given datatype. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The new JdbcType if there is an override for this DataType + */ + def getJDBCType(dt: DataType): Option[JdbcType] = None + + /** + * Quotes the identifier. This is used to put quotes around the identifier in case the column + * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). + */ + def quoteIdentifier(colName: String): String = { + s""""$colName"""" + } +} + +/** + * :: DeveloperApi :: + * Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]]. + * + * If multiple matching dialects are registered then all matching ones will be + * tried in reverse order. A user-added dialect will thus be applied first, + * overwriting the defaults. + * + * Note that all new dialects are applied to new jdbc DataFrames only. Make + * sure to register your dialects first. + */ +@DeveloperApi +object JdbcDialects { + + private var dialects = List[JdbcDialect]() + + /** + * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. + * Readding an existing dialect will cause a move-to-front. + * @param dialect The new dialect. + */ + def registerDialect(dialect: JdbcDialect) : Unit = { + dialects = dialect :: dialects.filterNot(_ == dialect) + } + + /** + * Unregister a dialect. Does nothing if the dialect is not registered. + * @param dialect The jdbc dialect. + */ + def unregisterDialect(dialect : JdbcDialect) : Unit = { + dialects = dialects.filterNot(_ == dialect) + } + + registerDialect(MySQLDialect) + registerDialect(PostgresDialect) + + /** + * Fetch the JdbcDialect class corresponding to a given database url. + */ + private[sql] def get(url: String): JdbcDialect = { + val matchingDialects = dialects.filter(_.canHandle(url)) + matchingDialects.length match { + case 0 => NoopDialect + case 1 => matchingDialects.head + case _ => new AggregatedDialect(matchingDialects) + } + } +} + +/** + * :: DeveloperApi :: + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * @param dialects List of dialects. + */ +@DeveloperApi +class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(dialects.nonEmpty) + + override def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } +} + +/** + * :: DeveloperApi :: + * NOOP dialect object, always returning the neutral element. + */ +@DeveloperApi +case object NoopDialect extends JdbcDialect { + override def canHandle(url : String): Boolean = true +} + +/** + * :: DeveloperApi :: + * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. + */ +@DeveloperApi +case object PostgresDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Some(BinaryType) + } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { + Some(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("inet")) { + Some(StringType) + } else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case _ => None + } +} + +/** + * :: DeveloperApi :: + * Default mysql dialect to read bit/bitsets correctly. + */ +@DeveloperApi +case object MySQLDialect extends JdbcDialect { + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + Some(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Some(BooleanType) + } else None + } + + override def quoteIdentifier(colName: String): String = { + s"`$colName`" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala new file mode 100644 index 000000000000..cc918c237192 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.{Connection, DriverManager} +import java.util.Properties + +import scala.util.Try + +/** + * Util functions for JDBC tables. + */ +private[sql] object JdbcUtils { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + DriverManager.getConnection(url, connectionProperties) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. + Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index c099881a0122..dd8aaf647489 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -129,25 +129,26 @@ package object jdbc { */ def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() - val quirks = DriverQuirks.get(url) + val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => { val name = field.name - var typ: String = quirks.getJDBCType(field.dataType)._1 - if (typ == null) typ = field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case DecimalType.Unlimited => "DECIMAL(40,20)" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - } + val typ: String = + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case DecimalType.Unlimited => "DECIMAL(40,20)" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + }) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") }} @@ -162,10 +163,9 @@ package object jdbc { url: String, table: String, properties: Properties = new Properties()) { - val quirks = DriverQuirks.get(url) - var nullTypes: Array[Int] = df.schema.fields.map(field => { - var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 - if (nullType.isEmpty) { + val dialect = JdbcDialects.get(url) + val nullTypes: Array[Int] = df.schema.fields.map { field => + dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( field.dataType match { case IntegerType => java.sql.Types.INTEGER case LongType => java.sql.Types.BIGINT @@ -181,9 +181,8 @@ package object jdbc { case DecimalType.Unlimited => java.sql.Types.DECIMAL case _ => throw new IllegalArgumentException( s"Can't translate null value for field $field") - } - } else nullType.get - }).toArray + }) + } val rddSchema = df.schema df.foreachPartition { iterator => @@ -241,10 +240,10 @@ package object jdbc { } } } - + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 9c58b8e4bb16..afe2c6c11ac6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -43,7 +43,7 @@ private[sql] object InferSchema { } // perform schema inference on each row and merge afterwards - schemaData.mapPartitions { iter => + val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() iter.map { row => try { @@ -55,8 +55,13 @@ private[sql] object InferSchema { StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) } } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match { - case st: StructType => nullTypeToStringType(st) + }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) + + canonicalizeType(rootType) match { + case Some(st: StructType) => st + case _ => + // canonicalizeType erases all empty structs, including the only one we want to keep + StructType(Seq()) } } @@ -116,22 +121,35 @@ private[sql] object InferSchema { } } - private def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable, _) => - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case ArrayType(struct: StructType, containsNull) => - ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) - case other: DataType => other - } + /** + * Convert NullType to StringType and remove StructTypes with no fields + */ + private def canonicalizeType: DataType => Option[DataType] = { + case at@ArrayType(elementType, _) => + for { + canonicalType <- canonicalizeType(elementType) + } yield { + at.copy(canonicalType) + } - StructField(fieldName, newType, nullable) - } + case StructType(fields) => + val canonicalFields = for { + field <- fields + if field.name.nonEmpty + canonicalType <- canonicalizeType(field.dataType) + } yield { + field.copy(dataType = canonicalType) + } + + if (canonicalFields.nonEmpty) { + Some(StructType(canonicalFields)) + } else { + // per SPARK-8093: empty structs should be deleted + None + } - StructType(fields) + case NullType => Some(StringType) + case other => Some(other) } /** @@ -147,7 +165,7 @@ private[sql] object InferSchema { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { case (other: DataType, NullType) => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index c772cd1f53e5..69bf13e1e5a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -22,10 +22,10 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StructField, StructType} -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} private[sql] class DefaultSource @@ -154,12 +154,12 @@ private[sql] class JSONRelation( JacksonParser( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } else { JsonRDD.jsonStringToRow( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } } @@ -168,12 +168,12 @@ private[sql] class JSONRelation( JacksonParser( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } else { JsonRDD.jsonStringToRow( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala index 80bf74aa0260..1e6b1198d245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala @@ -21,7 +21,7 @@ import scala.collection.Map import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row import org.apache.spark.sql.types._ private[sql] object JacksonGenerator { @@ -33,7 +33,7 @@ private[sql] object JacksonGenerator { */ def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = { def valWriter: (DataType, Any) => Unit = { - case (_, null) | (NullType, _) => gen.writeNull() + case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v: String) => gen.writeString(v) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) @@ -48,16 +48,16 @@ private[sql] object JacksonGenerator { case (DateType, v) => gen.writeString(v.toString) case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) - case (ArrayType(ty, _), v: Seq[_] ) => + case (ArrayType(ty, _), v: Seq[_]) => gen.writeStartArray() - v.foreach(valWriter(ty,_)) + v.foreach(valWriter(ty, _)) gen.writeEndArray() - case (MapType(kv,vv, _), v: Map[_,_]) => + case (MapType(kv, vv, _), v: Map[_, _]) => gen.writeStartObject() v.foreach { p => gen.writeFieldName(p._1.toString) - valWriter(vv,p._2) + valWriter(vv, p._2) } gen.writeEndObject() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 81611513582a..6222addc9aa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.json import java.io.ByteArrayOutputStream -import java.sql.Timestamp import scala.collection.Map @@ -26,15 +25,17 @@ import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + private[sql] object JacksonParser { def apply( json: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[Row] = { + columnNameOfCorruptRecords: String): RDD[InternalRow] = { parseJson(json, schema, columnNameOfCorruptRecords) } @@ -55,27 +56,27 @@ private[sql] object JacksonParser { convertField(factory, parser, schema) case (VALUE_STRING, StringType) => - UTF8String(parser.getText) + UTF8String.fromString(parser.getText) case (VALUE_STRING, _) if parser.getTextLength < 1 => // guard the non string type null case (VALUE_STRING, DateType) => - DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime) + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) case (VALUE_STRING, TimestampType) => - new Timestamp(DateUtils.stringToTime(parser.getText).getTime) + DateTimeUtils.stringToTime(parser.getText).getTime * 10000L case (VALUE_NUMBER_INT, TimestampType) => - new Timestamp(parser.getLongValue) + parser.getLongValue * 10000L case (_, StringType) => val writer = new ByteArrayOutputStream() val generator = factory.createGenerator(writer, JsonEncoding.UTF8) generator.copyCurrentStructure(parser) generator.close() - UTF8String(writer.toByteArray) + UTF8String.fromBytes(writer.toByteArray) case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) => parser.getFloatValue @@ -129,7 +130,10 @@ private[sql] object JacksonParser { * * Fields in the json that are not defined in the requested schema will be dropped. */ - private def convertObject(factory: JsonFactory, parser: JsonParser, schema: StructType): Row = { + private def convertObject( + factory: JsonFactory, + parser: JsonParser, + schema: StructType): InternalRow = { val row = new GenericMutableRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { @@ -150,10 +154,11 @@ private[sql] object JacksonParser { private def convertMap( factory: JsonFactory, parser: JsonParser, - valueType: DataType): Map[String, Any] = { - val builder = Map.newBuilder[String, Any] + valueType: DataType): Map[UTF8String, Any] = { + val builder = Map.newBuilder[UTF8String, Any] while (nextUntil(parser, JsonToken.END_OBJECT)) { - builder += parser.getCurrentName -> convertField(factory, parser, valueType) + builder += + UTF8String.fromString(parser.getCurrentName) -> convertField(factory, parser, valueType) } builder.result() @@ -174,14 +179,14 @@ private[sql] object JacksonParser { private def parseJson( json: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[Row] = { + columnNameOfCorruptRecords: String): RDD[InternalRow] = { - def failedRecord(record: String): Seq[Row] = { + def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present val row = new GenericMutableRow(schema.length) for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { require(schema(corruptIndex).dataType == StringType) - row.update(corruptIndex, record) + row.update(corruptIndex, UTF8String.fromString(record)) } Seq(row) @@ -200,7 +205,7 @@ private[sql] object JacksonParser { // convertField wrap an object into a single value array when necessary. convertField(factory, parser, ArrayType(schema)) match { case null => failedRecord(record) - case list: Seq[Row @unchecked] => list + case list: Seq[InternalRow @unchecked] => list case _ => sys.error( s"Failed to parse record $record. Please make sure that each line of the file " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 4c32710a17bc..73d9520d6f53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,28 +17,28 @@ package org.apache.spark.sql.json -import java.sql.Timestamp - import scala.collection.Map -import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} +import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} -import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException} +import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper +import org.apache.spark.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.Logging +import org.apache.spark.unsafe.types.UTF8String + private[sql] object JsonRDD extends Logging { private[sql] def jsonStringToRow( json: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[Row] = { + columnNameOfCorruptRecords: String): RDD[InternalRow] = { parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) } @@ -141,7 +141,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case ArrayType(struct: StructType, containsNull) => ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) + case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } StructField(fieldName, newType, nullable) @@ -155,7 +155,7 @@ private[sql] object JsonRDD extends Logging { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2) match { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { case Some(commonType) => commonType case None => // t1 or t2 is a StructType, ArrayType, or an unexpected type. @@ -216,7 +216,7 @@ private[sql] object JsonRDD extends Logging { case map: Map[_, _] => StructType(Nil) // We have an array of arrays. If those element arrays do not have the same // element types, we will return ArrayType[StringType]. - case seq: Seq[_] => typeOfArray(seq) + case seq: Seq[_] => typeOfArray(seq) case value => typeOfPrimitiveValue(value) } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) @@ -318,7 +318,8 @@ private[sql] object JsonRDD extends Logging { parsed } catch { - case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil + case e: JsonProcessingException => + Map(columnNameOfCorruptRecords -> UTF8String.fromString(record)) :: Nil } } }) @@ -392,25 +393,25 @@ private[sql] object JsonRDD extends Logging { value match { // only support string as date case value: java.lang.String => - DateUtils.millisToDays(DateUtils.stringToTime(value).getTime) - case value: java.sql.Date => DateUtils.fromJavaDate(value) + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(value).getTime) + case value: java.sql.Date => DateTimeUtils.fromJavaDate(value) } } - private def toTimestamp(value: Any): Timestamp = { + private def toTimestamp(value: Any): Long = { value match { - case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) - case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) + case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L + case value: java.lang.Long => value * 10000L + case value: java.lang.String => DateTimeUtils.stringToTime(value).getTime * 10000L } } - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = { if (value == null) { null } else { desiredType match { - case StringType => UTF8String(toString(value)) + case StringType => UTF8String.fromString(toString(value)) case _ if value == null || value == "" => null // guard the non string type case IntegerType => value.asInstanceOf[IntegerType.InternalType] case LongType => toLong(value) @@ -422,7 +423,10 @@ private[sql] object JsonRDD extends Logging { value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case MapType(StringType, valueType, _) => val map = value.asInstanceOf[Map[String, Any]] - map.mapValues(enforceCorrectType(_, valueType)).map(identity) + map.map { + case (k, v) => + (UTF8String.fromString(k), enforceCorrectType(v, valueType)) + }.map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) case TimestampType => toTimestamp(value) @@ -430,7 +434,7 @@ private[sql] object JsonRDD extends Logging { } } - private def asRow(json: Map[String,Any], schema: StructType): Row = { + private def asRow(json: Map[String, Any], schema: StructType): InternalRow = { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 3f97a11ceb97..a9c600b139b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -44,7 +44,8 @@ package object sql { /** * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + * @deprecated As of 1.3.0, replaced by `DataFrame`. */ - @deprecated("1.3.0", "use DataFrame") + @deprecated("use DataFrame", "1.3.0") type SchemaRDD = DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala new file mode 100644 index 000000000000..4ab274ec17a0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.parquet.schema._ + +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SQLConf} + +/** + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and + * vice versa. + * + * Parquet format backwards-compatibility rules are respected when converting Parquet + * [[MessageType]] schemas. + * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * @constructor + * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL + * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. + * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL + * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which + * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` + * described in Parquet format spec. + * @param followParquetFormatSpec Whether to generate standard DECIMAL, LIST, and MAP structure when + * converting Spark SQL [[StructType]] to Parquet [[MessageType]]. For Spark 1.4.x and + * prior versions, Spark SQL only supports decimals with a max precision of 18 digits, and + * uses non-standard LIST and MAP structure. Note that the current Parquet format spec is + * backwards-compatible with these settings. If this argument is set to `false`, we fallback + * to old style non-standard behaviors. + */ +private[parquet] class CatalystSchemaConverter( + private val assumeBinaryIsString: Boolean, + private val assumeInt96IsTimestamp: Boolean, + private val followParquetFormatSpec: Boolean) { + + // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in + // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant. + def this() = this( + assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get) + + def this(conf: SQLConf) = this( + assumeBinaryIsString = conf.isParquetBinaryAsString, + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + followParquetFormatSpec = conf.followParquetFormatSpec) + + def this(conf: Configuration) = this( + assumeBinaryIsString = + conf.getBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get), + assumeInt96IsTimestamp = + conf.getBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get), + followParquetFormatSpec = + conf.getBoolean( + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get)) + + /** + * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. + */ + def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + + private def convert(parquetSchema: GroupType): StructType = { + val fields = parquetSchema.getFields.map { field => + field.getRepetition match { + case OPTIONAL => + StructField(field.getName, convertField(field), nullable = true) + + case REQUIRED => + StructField(field.getName, convertField(field), nullable = false) + + case REPEATED => + throw new AnalysisException( + s"REPEATED not supported outside LIST or MAP. Type: $field") + } + } + + StructType(fields) + } + + /** + * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. + */ + def convertField(parquetType: Type): DataType = parquetType match { + case t: PrimitiveType => convertPrimitiveField(t) + case t: GroupType => convertGroupField(t.asGroupType()) + } + + private def convertPrimitiveField(field: PrimitiveType): DataType = { + val typeName = field.getPrimitiveTypeName + val originalType = field.getOriginalType + + def typeString = + if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + + def typeNotImplemented() = + throw new AnalysisException(s"Parquet type not yet supported: $typeString") + + def illegalType() = + throw new AnalysisException(s"Illegal Parquet type: $typeString") + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. + def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val precision = field.getDecimalMetadata.getPrecision + val scale = field.getDecimalMetadata.getScale + + CatalystSchemaConverter.analysisRequire( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + typeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + originalType match { + case INT_8 => ByteType + case INT_16 => ShortType + case INT_32 | null => IntegerType + case DATE => DateType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(4)) + case TIME_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + originalType match { + case INT_64 | null => LongType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) + case TIMESTAMP_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT96 => + CatalystSchemaConverter.analysisRequire( + assumeInt96IsTimestamp, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + originalType match { + case UTF8 | ENUM => StringType + case null if assumeBinaryIsString => StringType + case null => BinaryType + case DECIMAL => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + originalType match { + case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) + case INTERVAL => typeNotImplemented() + case _ => illegalType() + } + + case _ => illegalType() + } + } + + private def convertGroupField(field: GroupType): DataType = { + Option(field.getOriginalType).fold(convert(field): DataType) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. + // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case LIST => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1, s"Invalid list type $field") + + val repeatedType = field.getType(0) + CatalystSchemaConverter.analysisRequire( + repeatedType.isRepetition(REPEATED), s"Invalid list type $field") + + if (isElementType(repeatedType, field.getName)) { + ArrayType(convertField(repeatedType), containsNull = false) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val optional = elementType.isRepetition(OPTIONAL) + ArrayType(convertField(elementType), containsNull = optional) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case MAP | MAP_KEY_VALUE => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + + val keyValueType = field.getType(0).asGroupType() + CatalystSchemaConverter.analysisRequire( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val keyType = keyValueType.getType(0) + CatalystSchemaConverter.analysisRequire( + keyType.isPrimitive, + s"Map key type is expected to be a primitive type, but found: $keyType") + + val valueType = keyValueType.getType(1) + val valueOptional = valueType.isRepetition(OPTIONAL) + MapType( + convertField(keyType), + convertField(valueType), + valueContainsNull = valueOptional) + + case _ => + throw new AnalysisException(s"Unrecognized Parquet type: $field") + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + private def isElementType(repeatedType: Type, parentName: String): Boolean = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // List (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } + + /** + * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. + */ + def convert(catalystSchema: StructType): MessageType = { + Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root") + } + + /** + * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. + */ + def convertField(field: StructField): Type = { + convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + } + + private def convertField(field: StructField, repetition: Type.Repetition): Type = { + CatalystSchemaConverter.checkFieldName(field.name) + + field.dataType match { + // =================== + // Simple atomic types + // =================== + + case BooleanType => + Types.primitive(BOOLEAN, repetition).named(field.name) + + case ByteType => + Types.primitive(INT32, repetition).as(INT_8).named(field.name) + + case ShortType => + Types.primitive(INT32, repetition).as(INT_16).named(field.name) + + case IntegerType => + Types.primitive(INT32, repetition).named(field.name) + + case LongType => + Types.primitive(INT64, repetition).named(field.name) + + case FloatType => + Types.primitive(FLOAT, repetition).named(field.name) + + case DoubleType => + Types.primitive(DOUBLE, repetition).named(field.name) + + case StringType => + Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + + case DateType => + Types.primitive(INT32, repetition).as(DATE).named(field.name) + + // NOTE: !! This timestamp type is not specified in Parquet format spec !! + // However, Impala and older versions of Spark SQL use INT96 to store timestamps with + // nanosecond precision (not TIME_MILLIS or TIMESTAMP_MILLIS described in the spec). + case TimestampType => + Types.primitive(INT96, repetition).named(field.name) + + case BinaryType => + Types.primitive(BINARY, repetition).named(field.name) + + // ===================================== + // Decimals (for Spark version <= 1.4.x) + // ===================================== + + // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and + // always store decimals in fixed-length byte arrays. + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType() if !followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. " + + s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," + + "decimal precision and scale must be specified, " + + "and precision must be less than or equal to 18.") + + // ===================================== + // Decimals (follow Parquet format spec) + // ===================================== + + // Uses INT32 for 1 <= precision <= 9 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + Types + .primitive(INT32, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses INT64 for 1 <= precision <= 18 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + Types + .primitive(INT64, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses FIXED_LEN_BYTE_ARRAY for all other precisions + case DecimalType.Fixed(precision, scale) if followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType.Unlimited if followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. Decimal precision and scale must be specified.") + + // =================================================== + // ArrayType and MapType (for Spark versions <= 1.4.x) + // =================================================== + + // Spark 1.4.x and prior versions convert ArrayType with nullable elements into a 3-level + // LIST structure. This behavior mimics parquet-hive (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ true) if !followParquetFormatSpec => + // group (LIST) { + // optional group bag { + // repeated element; + // } + // } + ConversionPatterns.listType( + repetition, + field.name, + Types + .buildGroup(REPEATED) + .addField(convertField(StructField("element", elementType, nullable))) + .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) + + // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level + // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ false) if !followParquetFormatSpec => + // group (LIST) { + // repeated element; + // } + ConversionPatterns.listType( + repetition, + field.name, + convertField(StructField("element", elementType, nullable), REPEATED)) + + // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by + // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. + case MapType(keyType, valueType, valueContainsNull) if !followParquetFormatSpec => + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // required key; + // value; + // } + // } + ConversionPatterns.mapType( + repetition, + field.name, + convertField(StructField("key", keyType, nullable = false)), + convertField(StructField("value", valueType, valueContainsNull))) + + // ================================================== + // ArrayType and MapType (follow Parquet format spec) + // ================================================== + + case ArrayType(elementType, containsNull) if followParquetFormatSpec => + // group (LIST) { + // repeated group list { + // element; + // } + // } + Types + .buildGroup(repetition).as(LIST) + .addField( + Types.repeatedGroup() + .addField(convertField(StructField("element", elementType, containsNull))) + .named("list")) + .named(field.name) + + case MapType(keyType, valueType, valueContainsNull) => + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + Types + .buildGroup(repetition).as(MAP) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("key", keyType, nullable = false))) + .addField(convertField(StructField("value", valueType, valueContainsNull))) + .named("key_value")) + .named(field.name) + + // =========== + // Other types + // =========== + + case StructType(fields) => + fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => + builder.addField(convertField(field)) + }.named(field.name) + + case udt: UserDefinedType[_] => + convertField(field.copy(dataType = udt.sqlType)) + + case _ => + throw new AnalysisException(s"Unsupported data type $field.dataType") + } + } + + // Max precision of a decimal value stored in `numBytes` bytes + private def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } + + // Min byte counts needed to store decimals with various precisions + private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision => + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } +} + + +private[parquet] object CatalystSchemaConverter { + def checkFieldName(name: String): Unit = { + // ,;{}()\n\t= and space are special characters in Parquet schema + analysisRequire( + !name.matches(".*[ ,;{}()\n\t=].*"), + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + + def analysisRequire(f: => Boolean, message: String): Unit = { + if (!f) { + throw new AnalysisException(message) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index f5ce2718bec4..1551afd7b7bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -17,19 +17,35 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.parquet.Log +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} -import parquet.Log -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} - +/** + * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder + * like what [[ParquetOutputCommitter]] does, this output committer writes data directly to the + * destination folder. This can be useful for data stored in S3, where directory operations are + * relatively expensive. + * + * To enable this output committer, users may set the "spark.sql.parquet.output.committer.class" + * property via Hadoop [[Configuration]]. Not that this property overrides + * "spark.sql.sources.outputCommitterClass". + * + * *NOTE* + * + * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's + * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are + * left * empty). + */ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - override def getWorkPath(): Path = outputPath + override def getWorkPath: Path = outputPath override def abortTask(taskContext: TaskAttemptContext): Unit = {} override def commitTask(taskContext: TaskAttemptContext): Unit = {} override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true @@ -46,13 +62,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) try { ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - } catch { - case e: Exception => { - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } + } catch { case e: Exception => + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) } } } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 36cb5e03bbca..ae7cbf0624dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.parquet -import java.sql.Timestamp -import java.util.{TimeZone, Calendar} +import java.nio.ByteOrder -import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} +import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap} -import jodd.datetime.JDateTime -import parquet.column.Dictionary -import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} -import parquet.schema.MessageType +import org.apache.parquet.Preconditions +import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ -import org.apache.spark.sql.parquet.timestamp.NanoTime +import org.apache.spark.unsafe.types.UTF8String /** * Collection of converters of Parquet types (group and primitive types) that @@ -77,7 +77,7 @@ private[sql] object CatalystConverter { // TODO: consider using Array[T] for arrays to avoid boxing of primitive types type ArrayScalaType[T] = Seq[T] - type StructScalaType[T] = Row + type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] protected[parquet] def createConverter( @@ -221,7 +221,7 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { updateField(fieldIndex, value.getBytes) protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - updateField(fieldIndex, UTF8String(value)) + updateField(fieldIndex, UTF8String.fromBytes(value)) protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, readTimestamp(value)) @@ -238,13 +238,15 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { * * @return */ - def getCurrentRecord: Row = throw new UnsupportedOperationException + def getCurrentRecord: InternalRow = throw new UnsupportedOperationException /** * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in * a long (i.e. precision <= 18) + * + * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { val precision = ctype.precisionInfo.get.precision val scale = ctype.precisionInfo.get.scale val bytes = value.getBytes @@ -264,14 +266,19 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { /** * Read a Timestamp value from a Parquet Int96Value */ - protected[parquet] def readTimestamp(value: Binary): Timestamp = { - CatalystTimestampConverter.convertToTimestamp(value) + protected[parquet] def readTimestamp(value: Binary): Long = { + Preconditions.checkArgument(value.length() == 12, "Must be 12 bytes") + val buf = value.toByteBuffer + buf.order(ByteOrder.LITTLE_ENDIAN) + val timeOfDayNanos = buf.getLong + val julianDay = buf.getInt + DateTimeUtils.fromJulianDay(julianDay, timeOfDayNanos) } } /** * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. + * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. * * @param schema The corresponding Catalyst schema in the form of a list of attributes. */ @@ -280,7 +287,7 @@ private[parquet] class CatalystGroupConverter( protected[parquet] val index: Int, protected[parquet] val parent: CatalystConverter, protected[parquet] var current: ArrayBuffer[Any], - protected[parquet] var buffer: ArrayBuffer[Row]) + protected[parquet] var buffer: ArrayBuffer[InternalRow]) extends CatalystConverter { def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) = @@ -289,7 +296,7 @@ private[parquet] class CatalystGroupConverter( index, parent, current = null, - buffer = new ArrayBuffer[Row]( + buffer = new ArrayBuffer[InternalRow]( CatalystArrayConverter.INITIAL_ARRAY_SIZE)) /** @@ -305,13 +312,13 @@ private[parquet] class CatalystGroupConverter( override val size = schema.size - override def getCurrentRecord: Row = { + override def getCurrentRecord: InternalRow = { assert(isRootConverter, "getCurrentRecord should only be called in root group converter!") // TODO: use iterators if possible // Note: this will ever only be called in the root converter when the record has been // fully processed. Therefore it will be difficult to use mutable rows instead, since // any non-root converter never would be sure when it would be safe to re-use the buffer. - new GenericRow(current.toArray) + new GenericInternalRow(current.toArray) } override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) @@ -335,15 +342,15 @@ private[parquet] class CatalystGroupConverter( override def end(): Unit = { if (!isRootConverter) { assert(current != null) // there should be no empty groups - buffer.append(new GenericRow(current.toArray)) - parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) + buffer.append(new GenericInternalRow(current.toArray)) + parent.updateField(index, new GenericInternalRow(buffer.toArray.asInstanceOf[Array[Any]])) } } } /** * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record - * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his + * to a [[org.apache.spark.sql.catalyst.expressions.InternalRow]] object. Note that his * converter is optimized for rows of primitive types (non-nested records). */ private[parquet] class CatalystPrimitiveRowConverter( @@ -369,7 +376,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override val parent = null // Should be only called in root group converter! - override def getCurrentRecord: Row = current + override def getCurrentRecord: InternalRow = current override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) @@ -399,7 +406,7 @@ private[parquet] class CatalystPrimitiveRowConverter( current.setInt(fieldIndex, value) override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - current.update(fieldIndex, value) + current.setInt(fieldIndex, value) override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) @@ -420,10 +427,10 @@ private[parquet] class CatalystPrimitiveRowConverter( current.update(fieldIndex, value.getBytes) override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = - current.update(fieldIndex, UTF8String(value)) + current.update(fieldIndex, UTF8String.fromBytes(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, readTimestamp(value)) + current.setLong(fieldIndex, readTimestamp(value)) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { @@ -480,7 +487,7 @@ private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverte override def hasDictionarySupport: Boolean = true - override def setDictionary(dictionary: Dictionary):Unit = + override def setDictionary(dictionary: Dictionary): Unit = dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = @@ -494,73 +501,6 @@ private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } -private[parquet] object CatalystTimestampConverter { - // TODO most part of this comes from Hive-0.14 - // Hive code might have some issues, so we need to keep an eye on it. - // Also we use NanoTime and Int96Values from parquet-examples. - // We utilize jodd to convert between NanoTime and Timestamp - val parquetTsCalendar = new ThreadLocal[Calendar] - def getCalendar: Calendar = { - // this is a cache for the calendar instance. - if (parquetTsCalendar.get == null) { - parquetTsCalendar.set(Calendar.getInstance(TimeZone.getTimeZone("GMT"))) - } - parquetTsCalendar.get - } - val NANOS_PER_SECOND: Long = 1000000000 - val SECONDS_PER_MINUTE: Long = 60 - val MINUTES_PER_HOUR: Long = 60 - val NANOS_PER_MILLI: Long = 1000000 - - def convertToTimestamp(value: Binary): Timestamp = { - val nt = NanoTime.fromBinary(value) - val timeOfDayNanos = nt.getTimeOfDayNanos - val julianDay = nt.getJulianDay - val jDateTime = new JDateTime(julianDay.toDouble) - val calendar = getCalendar - calendar.set(Calendar.YEAR, jDateTime.getYear) - calendar.set(Calendar.MONTH, jDateTime.getMonth - 1) - calendar.set(Calendar.DAY_OF_MONTH, jDateTime.getDay) - - // written in command style - var remainder = timeOfDayNanos - calendar.set( - Calendar.HOUR_OF_DAY, - (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR)).toInt) - remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR) - calendar.set( - Calendar.MINUTE, (remainder / (NANOS_PER_SECOND * SECONDS_PER_MINUTE)).toInt) - remainder = remainder % (NANOS_PER_SECOND * SECONDS_PER_MINUTE) - calendar.set(Calendar.SECOND, (remainder / NANOS_PER_SECOND).toInt) - val nanos = remainder % NANOS_PER_SECOND - val ts = new Timestamp(calendar.getTimeInMillis) - ts.setNanos(nanos.toInt) - ts - } - - def convertFromTimestamp(ts: Timestamp): Binary = { - val calendar = getCalendar - calendar.setTime(ts) - val jDateTime = new JDateTime(calendar.get(Calendar.YEAR), - calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH)) - // Hive-0.14 didn't set hour before get day number, while the day number should - // has something to do with hour, since julian day number grows at 12h GMT - // here we just follow what hive does. - val julianDay = jDateTime.getJulianDayNumber - - val hour = calendar.get(Calendar.HOUR_OF_DAY) - val minute = calendar.get(Calendar.MINUTE) - val second = calendar.get(Calendar.SECOND) - val nanos = ts.getNanos - // Hive-0.14 would use hours directly, that might be wrong, since the day starts - // from 12h in Julian. here we just follow what hive does. - val nanosOfDay = nanos + second * NANOS_PER_SECOND + - minute * NANOS_PER_SECOND * SECONDS_PER_MINUTE + - hour * NANOS_PER_SECOND * SECONDS_PER_MINUTE * MINUTES_PER_HOUR - NanoTime(julianDay, nanosOfDay).toBinary - } -} - /** * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array (see @@ -591,8 +531,8 @@ private[parquet] class CatalystArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter @@ -601,7 +541,7 @@ private[parquet] class CatalystArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { // fieldIndex is ignored (assumed to be zero but not checked) - if(value == null) { + if (value == null) { throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") } buffer += value @@ -654,8 +594,8 @@ private[parquet] class CatalystNativeArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter @@ -716,7 +656,7 @@ private[parquet] class CatalystNativeArrayConverter( override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = UTF8String(value).asInstanceOf[NativeType] + buffer(elements) = UTF8String.fromBytes(value).asInstanceOf[NativeType] elements += 1 } @@ -848,7 +788,7 @@ private[parquet] class CatalystStructConverter( // here we need to make sure to use StructScalaType // Note: we need to actually make a copy of the array since we // may be in a nested field - parent.updateField(index, new GenericRow(current.toArray)) + parent.updateField(index, new GenericInternalRow(current.toArray)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f0f4e7d147e7..d57b789f5c1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -17,20 +17,23 @@ package org.apache.spark.sql.parquet +import java.io.Serializable import java.nio.ByteBuffer import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration -import parquet.filter2.compat.FilterCompat -import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterApi._ -import parquet.filter2.predicate.{FilterApi, FilterPredicate} -import parquet.io.api.Binary +import org.apache.parquet.filter2.compat.FilterCompat +import org.apache.parquet.filter2.compat.FilterCompat._ +import org.apache.parquet.filter2.predicate.FilterApi._ +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics} +import org.apache.parquet.filter2.predicate.UserDefinedPredicate +import org.apache.parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -41,6 +44,18 @@ private[sql] object ParquetFilters { }.reduceOption(FilterApi.and).map(FilterCompat.get) } + case class SetInFilter[T <: Comparable[T]]( + valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { + + override def keep(value: T): Boolean = { + value != null && valueSet.contains(value) + } + + override def canDrop(statistics: Statistics[T]): Boolean = false + + override def inverseCanDrop(statistics: Statistics[T]): Boolean = false + } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) @@ -153,6 +168,29 @@ private[sql] object ParquetFilters { FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } + private val makeInSet: PartialFunction[DataType, (String, Set[Any]) => FilterPredicate] = { + case IntegerType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(intColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Integer]])) + case LongType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(longColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Long]])) + case FloatType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(floatColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Float]])) + case DoubleType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(doubleColumn(n), SetInFilter(v.asInstanceOf[Set[java.lang.Double]])) + case StringType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(binaryColumn(n), + SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes)))) + case BinaryType => + (n: String, v: Set[Any]) => + FilterApi.userDefined(binaryColumn(n), + SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[Array[Byte]])))) + } + /** * Converts data sources filters to Parquet filter predicates. */ @@ -284,6 +322,9 @@ private[sql] object ParquetFilters { case Not(pred) => createFilter(pred).map(FilterApi.not) + case InSet(NamedExpression(name, dataType), valueSet) => + makeInSet.lift(dataType).map(_(name, valueSet)) + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index fcb9513ab66f..704cf56f3826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -18,20 +18,21 @@ package org.apache.spark.sql.parquet import java.io.IOException -import java.util.logging.Level +import java.util.logging.{Level, Logger => JLogger} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import org.apache.spark.sql.types.{StructType, DataType} -import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} -import parquet.hadoop.metadata.CompressionCodecName -import parquet.schema.MessageType +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat, ParquetRecordReader} +import org.apache.parquet.schema.MessageType +import org.apache.parquet.{Log => ParquetLog} -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext} /** * Relation that consists of data stored in a Parquet columnar format. @@ -94,40 +95,44 @@ private[sql] case class ParquetRelation( private[sql] object ParquetRelation { def enableLogForwarding() { - // Note: the parquet.Log class has a static initializer that - // sets the java.util.logging Logger for "parquet". This + // Note: the org.apache.parquet.Log class has a static initializer that + // sets the java.util.logging Logger for "org.apache.parquet". This // checks first to see if there's any handlers already set // and if not it creates them. If this method executes prior // to that class being loaded then: // 1) there's no handlers installed so there's none to // remove. But when it IS finally loaded the desired affect // of removing them is circumvented. - // 2) The parquet.Log static initializer calls setUseParentHanders(false) + // 2) The parquet.Log static initializer calls setUseParentHandlers(false) // undoing the attempt to override the logging here. // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName) + Class.forName(classOf[ParquetLog].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. - val parquetLogger = java.util.logging.Logger.getLogger("parquet") + val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName) parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) - // TODO(witgo): Need to set the log level ? - // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) - if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) + parquetLogger.setUseParentHandlers(true) - // Disables WARN log message in ParquetOutputCommitter. + // Disables a WARN log message in ParquetOutputCommitter. We first ensure that + // ParquetOutputCommitter is loaded and the static LOG field gets initialized. // See https://issues.apache.org/jira/browse/SPARK-5968 for details Class.forName(classOf[ParquetOutputCommitter].getName) - java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) + JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) + + // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. + // See https://issues.apache.org/jira/browse/PARQUET-220 for details + Class.forName(classOf[ParquetRecordReader[_]].getName) + JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) } // The element type for the RDDs that this relation maps to. type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow // The compression type - type CompressionType = parquet.hadoop.metadata.CompressionCodecName + type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName // The parquet compression short names val shortParquetCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 90950f924a05..b30fc171c0af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -33,28 +33,29 @@ import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import parquet.hadoop._ -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{InitContext, ReadSupport} -import parquet.hadoop.metadata.GlobalMetaData -import parquet.hadoop.util.ContextUtil -import parquet.io.ParquetDecodingException -import parquet.schema.MessageType +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.hadoop.metadata.GlobalMetaData +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.io.ParquetDecodingException +import org.apache.parquet.schema.MessageType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, SerializableWritable, TaskContext} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.util.SerializableConfiguration /** * :: DeveloperApi :: * Parquet table scan operator. Imports the file that backs the given - * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. + * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[InternalRow]``. */ private[sql] case class ParquetTableScan( attributes: Seq[Attribute], @@ -77,8 +78,8 @@ private[sql] case class ParquetTableScan( } }.toArray - protected override def doExecute(): RDD[Row] = { - import parquet.filter2.compat.FilterCompat.FilterPredicateCompat + protected override def doExecute(): RDD[InternalRow] = { + import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -113,16 +114,19 @@ private[sql] case class ParquetTableScan( .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.set( - SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) + conf.setBoolean( + SQLConf.PARQUET_CACHE_METADATA.key, + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, true)) + + // Use task side metadata in parquet + conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( sc, classOf[FilteringParquetRowInputFormat], classOf[Void], - classOf[Row], + classOf[InternalRow], conf) if (requestedPartitionOrdinals.nonEmpty) { @@ -136,7 +140,7 @@ private[sql] case class ParquetTableScan( baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] .getPath .toString .split("/") @@ -151,9 +155,9 @@ private[sql] case class ParquetTableScan( .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) if (primitiveRow) { - new Iterator[Row] { + new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext - def next(): Row = { + def next(): InternalRow = { // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. val row = iter.next()._2.asInstanceOf[SpecificMutableRow] @@ -170,12 +174,12 @@ private[sql] case class ParquetTableScan( } else { // Create a mutable row since we need to fill in values from partition columns. val mutableRow = new GenericMutableRow(outputSize) - new Iterator[Row] { + new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext - def next(): Row = { + def next(): InternalRow = { // We are using CatalystGroupConverter and it returns a GenericRow. // Since GenericRow is not mutable, we just cast it to a Row. - val row = iter.next()._2.asInstanceOf[Row] + val row = iter.next()._2.asInstanceOf[InternalRow] var i = 0 while (i < row.size) { @@ -255,7 +259,7 @@ private[sql] case class InsertIntoParquetTable( /** * Inserts all rows into the Parquet file. */ - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { // TODO: currently we do not check whether the "schema"s are compatible // That means if one first creates a table and then INSERTs data with // and incompatible schema the execution will fail. It would be nice @@ -318,15 +322,15 @@ private[sql] case class InsertIntoParquetTable( * @param conf A [[org.apache.hadoop.conf.Configuration]]. */ private def saveAsHadoopFile( - rdd: RDD[Row], + rdd: RDD[InternalRow], path: String, conf: Configuration) { val job = new Job(conf) val keyType = classOf[Void] job.setOutputKeyClass(keyType) - job.setOutputValueClass(classOf[Row]) + job.setOutputValueClass(classOf[InternalRow]) NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = sqlContext.sparkContext.newRddId() @@ -339,7 +343,7 @@ private[sql] case class InsertIntoParquetTable( .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 } - def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { + def writeShard(context: TaskContext, iter: Iterator[InternalRow]): Int = { /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, context.attemptNumber) @@ -378,7 +382,7 @@ private[sql] case class InsertIntoParquetTable( * to imported ones. */ private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends parquet.hadoop.ParquetOutputFormat[Row] { + extends org.apache.parquet.hadoop.ParquetOutputFormat[InternalRow] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} var committer: OutputCommitter = null @@ -431,210 +435,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) * RecordFilter we want to use. */ private[parquet] class FilteringParquetRowInputFormat - extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + extends org.apache.parquet.hadoop.ParquetInputFormat[InternalRow] with Logging { private var fileStatuses = Map.empty[Path, FileStatus] override def createRecordReader( inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { + taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { - import parquet.filter2.compat.FilterCompat.NoOpFilter + import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter - val readSupport: ReadSupport[Row] = new RowReadSupport() + val readSupport: ReadSupport[InternalRow] = new RowReadSupport() val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) if (!filter.isInstanceOf[NoOpFilter]) { - new ParquetRecordReader[Row]( + new ParquetRecordReader[InternalRow]( readSupport, filter) } else { - new ParquetRecordReader[Row](readSupport) - } - } - - // This is only a temporary solution sicne we need to use fileStatuses in - // both getClientSideSplits and getTaskSideSplits. It can be removed once we get rid of these - // two methods. - override def getSplits(jobContext: JobContext): JList[InputSplit] = { - // First set fileStatuses. - val statuses = listStatus(jobContext) - fileStatuses = statuses.map(file => file.getPath -> file).toMap - - super.getSplits(jobContext) - } - - // TODO Remove this method and related code once PARQUET-16 is fixed - // This method together with the `getFooters` method and the `fileStatuses` field are just used - // to mimic this PR: https://github.com/apache/incubator-parquet-mr/pull/17 - override def getSplits( - configuration: Configuration, - footers: JList[Footer]): JList[ParquetInputSplit] = { - - // Use task side strategy by default - val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) - val minSplitSize: JLong = - Math.max(getFormatMinSplitSize, configuration.getLong("mapred.min.split.size", 0L)) - if (maxSplitSize < 0 || minSplitSize < 0) { - throw new ParquetDecodingException( - s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + - s" minSplitSize = $minSplitSize") - } - - // Uses strict type checking by default - val getGlobalMetaData = - classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) - getGlobalMetaData.setAccessible(true) - var globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] - - if (globalMetaData == null) { - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - return splits - } - - val metadata = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) - val mergedMetadata = globalMetaData - .getKeyValueMetaData - .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(metadata))) - - globalMetaData = new GlobalMetaData(globalMetaData.getSchema, - mergedMetadata, globalMetaData.getCreatedBy) - - val readContext = getReadSupport(configuration).init( - new InitContext(configuration, - globalMetaData.getKeyValueMetaData, - globalMetaData.getSchema)) - - if (taskSideMetaData){ - logInfo("Using Task Side Metadata Split Strategy") - getTaskSideSplits(configuration, - footers, - maxSplitSize, - minSplitSize, - readContext) - } else { - logInfo("Using Client Side Metadata Split Strategy") - getClientSideSplits(configuration, - footers, - maxSplitSize, - minSplitSize, - readContext) + new ParquetRecordReader[InternalRow](readSupport) } - - } - - def getClientSideSplits( - configuration: Configuration, - footers: JList[Footer], - maxSplitSize: JLong, - minSplitSize: JLong, - readContext: ReadContext): JList[ParquetInputSplit] = { - - import parquet.filter2.compat.FilterCompat.Filter - import parquet.filter2.compat.RowGroupFilter - - import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache - - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) - - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - val filter: Filter = ParquetInputFormat.getFilter(configuration) - var rowGroupsDropped: Long = 0 - var totalRowGroups: Long = 0 - - // Ugly hack, stuck with it until PR: - // https://github.com/apache/incubator-parquet-mr/pull/17 - // is resolved - val generateSplits = - Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") - .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( - sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) - generateSplits.setAccessible(true) - - for (footer <- footers) { - val fs = footer.getFile.getFileSystem(configuration) - val file = footer.getFile - val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) - val parquetMetaData = footer.getParquetMetadata - val blocks = parquetMetaData.getBlocks - totalRowGroups = totalRowGroups + blocks.size - val filteredBlocks = RowGroupFilter.filterRowGroups( - filter, - blocks, - parquetMetaData.getFileMetaData.getSchema) - rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size) - - if (!filteredBlocks.isEmpty){ - var blockLocations: Array[BlockLocation] = null - if (!cacheMetadata) { - blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - } else { - blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { - def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) - }) - } - splits.addAll( - generateSplits.invoke( - null, - filteredBlocks, - blockLocations, - status, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) - } - } - - if (rowGroupsDropped > 0 && totalRowGroups > 0){ - val percentDropped = ((rowGroupsDropped/totalRowGroups.toDouble) * 100).toInt - logInfo(s"Dropping $rowGroupsDropped row groups that do not pass filter predicate " - + s"($percentDropped %) !") - } - else { - logInfo("There were no row groups that could be dropped due to filter predicates") - } - splits - - } - - def getTaskSideSplits( - configuration: Configuration, - footers: JList[Footer], - maxSplitSize: JLong, - minSplitSize: JLong, - readContext: ReadContext): JList[ParquetInputSplit] = { - - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - - // Ugly hack, stuck with it until PR: - // https://github.com/apache/incubator-parquet-mr/pull/17 - // is resolved - val generateSplits = - Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") - .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( - sys.error( - s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) - generateSplits.setAccessible(true) - - for (footer <- footers) { - val file = footer.getFile - val fs = file.getFileSystem(configuration) - val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) - val blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - splits.addAll( - generateSplits.invoke( - null, - blockLocations, - status, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) - } - - splits } } @@ -664,7 +484,7 @@ private[parquet] object FileSystemHelper { s"ParquetTableOperations: path $path does not exist or is not a directory") } fs.globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .map(_.getPath) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index c45c431438ef..df2a96dfeb61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,19 +17,22 @@ package org.apache.spark.sql.parquet +import java.nio.{ByteOrder, ByteBuffer} import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration -import parquet.column.ParquetProperties -import parquet.hadoop.ParquetOutputFormat -import parquet.hadoop.api.ReadSupport.ReadContext -import parquet.hadoop.api.{ReadSupport, WriteSupport} -import parquet.io.api._ -import parquet.schema.MessageType +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{ReadSupport, WriteSupport} +import org.apache.parquet.io.api._ +import org.apache.parquet.schema.MessageType import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * A `parquet.io.api.RecordMaterializer` for Rows. @@ -37,12 +40,12 @@ import org.apache.spark.sql.types._ *@param root The root group converter for the record. */ private[parquet] class RowRecordMaterializer(root: CatalystConverter) - extends RecordMaterializer[Row] { + extends RecordMaterializer[InternalRow] { def this(parquetSchema: MessageType, attributes: Seq[Attribute]) = this(CatalystConverter.createRootConverter(parquetSchema, attributes)) - override def getCurrentRecord: Row = root.getCurrentRecord + override def getCurrentRecord: InternalRow = root.getCurrentRecord override def getRootConverter: GroupConverter = root.asInstanceOf[GroupConverter] } @@ -50,13 +53,13 @@ private[parquet] class RowRecordMaterializer(root: CatalystConverter) /** * A `parquet.hadoop.api.ReadSupport` for Row objects. */ -private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { +private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logging { override def prepareForRead( conf: Configuration, stringMap: java.util.Map[String, String], fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[Row] = { + readContext: ReadContext): RecordMaterializer[InternalRow] = { log.debug(s"preparing for read with Parquet file schema $fileSchema") // Note: this very much imitates AvroParquet val parquetSchema = readContext.getRequestedSchema @@ -83,8 +86,7 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes( - parquetSchema, false, true) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false, true) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) @@ -102,8 +104,7 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { // If the parquet file is thrift derived, there is a good chance that // it will have the thrift class in metadata. val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") - parquetSchema = ParquetTypesConverter - .convertFromAttributes(requestedAttributes, isThriftDerived) + parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) metadata.put( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertToString(requestedAttributes)) @@ -129,9 +130,9 @@ private[parquet] object RowReadSupport { } /** - * A `parquet.hadoop.api.WriteSupport` for Row ojects. + * A `parquet.hadoop.api.WriteSupport` for Row objects. */ -private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { +private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging { private[parquet] var writer: RecordConsumer = null private[parquet] var attributes: Array[Attribute] = null @@ -155,7 +156,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { log.debug(s"preparing for write with schema $attributes") } - override def write(record: Row): Unit = { + override def write(record: InternalRow): Unit = { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( @@ -197,19 +198,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case ByteType => writer.addInteger(value.asInstanceOf[Byte]) + case ShortType => writer.addInteger(value.asInstanceOf[Short]) + case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) + case LongType => writer.addLong(value.asInstanceOf[Long]) + case TimestampType => writeTimestamp(value.asInstanceOf[Long]) + case FloatType => writer.addFloat(value.asInstanceOf[Float]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case StringType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(value.asInstanceOf[Int]) - case ShortType => writer.addInteger(value.asInstanceOf[Short]) - case LongType => writer.addLong(value.asInstanceOf[Long]) - case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) - case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case DateType => writer.addInteger(value.asInstanceOf[Int]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -296,7 +296,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } // Scratch array used to write decimals as fixed-length binary - private val scratchBytes = new Array[Byte](8) + private[this] val scratchBytes = new Array[Byte](8) private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) @@ -311,15 +311,22 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } - private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = { - val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts) - writer.addBinary(binaryNanoTime) + // array used to write Timestamp as Int96 (fixed-length binary) + private[this] val int96buf = new Array[Byte](12) + + private[parquet] def writeTimestamp(ts: Long): Unit = { + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(ts) + val buf = ByteBuffer.wrap(int96buf) + buf.order(ByteOrder.LITTLE_ENDIAN) + buf.putLong(timeOfDayNanos) + buf.putInt(julianDay) + writer.addBinary(Binary.fromByteArray(int96buf)) } } // Optimized for non-nested rows private[parquet] class MutableRowWriteSupport extends RowWriteSupport { - override def write(record: Row): Unit = { + override def write(record: InternalRow): Unit = { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( @@ -342,22 +349,21 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { private def consumeType( ctype: DataType, - record: Row, + record: InternalRow, index: Int): Unit = { ctype match { + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case IntegerType | DateType => writer.addInteger(record.getInt(index)) + case LongType => writer.addLong(record.getLong(index)) + case TimestampType => writeTimestamp(record.getLong(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) case StringType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(record.getInt(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case LongType => writer.addLong(record.getLong(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case DateType => writer.addInteger(record.getInt(index)) - case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 1dc819b5d7b9..e748bd7857bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,224 +19,27 @@ package org.apache.spark.sql.parquet import java.io.IOException -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job -import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import parquet.schema.Type.Repetition -import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} - -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.schema.MessageType + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.types._ -import org.apache.spark.{Logging, SparkException} - -// Implicits -import scala.collection.JavaConversions._ -/** A class representing Parquet info fields we care about, for passing back to Parquet */ -private[parquet] case class ParquetTypeInfo( - primitiveType: ParquetPrimitiveTypeName, - originalType: Option[ParquetOriginalType] = None, - decimalMetadata: Option[DecimalMetadata] = None, - length: Option[Int] = None) private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | StringType | BinaryType => true - case _: DataType => false - } - - def toPrimitiveDataType( - parquetType: ParquetPrimitiveType, - binaryAsString: Boolean, - int96AsTimestamp: Boolean): DataType = { - val originalType = parquetType.getOriginalType - val decimalInfo = parquetType.getDecimalMetadata - parquetType.getPrimitiveTypeName match { - case ParquetPrimitiveTypeName.BINARY - if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType - case ParquetPrimitiveTypeName.BINARY => BinaryType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 - if originalType == ParquetOriginalType.DATE => DateType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - sys.error("Potential loss of precision: cannot convert INT96") - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => - // TODO: for now, our reader only supports decimals that fit in a Long - DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") - } - } - - /** - * Converts a given Parquet `Type` into the corresponding - * [[org.apache.spark.sql.types.DataType]]. - * - * We apply the following conversion rules: - *

      - *
    • Primitive types are converter to the corresponding primitive type.
    • - *
    • Group types that have a single field that is itself a group, which has repetition - * level `REPEATED`, are treated as follows:
        - *
      • If the nested group has name `values`, the surrounding group is converted - * into an [[ArrayType]] with the corresponding field type (primitive or - * complex) as element type.
      • - *
      • If the nested group has name `map` and two fields (named `key` and `value`), - * the surrounding group is converted into a [[MapType]] - * with the corresponding key and value (value possibly complex) types. - * Note that we currently assume map values are not nullable.
      • - *
      • Other group types are converted into a [[StructType]] with the corresponding - * field types.
    • - *
    - * Note that fields are determined to be `nullable` if and only if their Parquet repetition - * level is not `REQUIRED`. - * - * @param parquetType The type to convert. - * @return The corresponding Catalyst type. - */ - def toDataType(parquetType: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): DataType = { - def correspondsToMap(groupType: ParquetGroupType): Boolean = { - if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { - false - } else { - // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - keyValueGroup.getRepetition == Repetition.REPEATED && - keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && - keyValueGroup.getFieldCount == 2 && - keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && - keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME - } - } - - def correspondsToArray(groupType: ParquetGroupType): Boolean = { - groupType.getFieldCount == 1 && - groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && - groupType.getFields.apply(0).getRepetition == Repetition.REPEATED - } - - if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) - } else { - val groupType = parquetType.asGroupType() - parquetType.getOriginalType match { - // if the schema was constructed programmatically there may be hints how to convert - // it inside the metadata via the OriginalType field - case ParquetOriginalType.LIST => { // TODO: check enums! - assert(groupType.getFieldCount == 1) - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } - case ParquetOriginalType.MAP => { - assert( - !groupType.getFields.apply(0).isPrimitive, - "Parquet Map type malformatted: expected nested group for map!") - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert( - keyValueGroup.getFieldCount == 2, - "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } - case _ => { - // Note: the order of these checks is important! - if (correspondsToMap(groupType)) { // MapType - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } else if (correspondsToArray(groupType)) { // ArrayType - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } else { // everything else: StructType - val fields = groupType - .getFields - .map(ptype => new StructField( - ptype.getName, - toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), - ptype.getRepetition != Repetition.REQUIRED)) - StructType(fields) - } - } - } - } - } - - /** - * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return - * the name of the corresponding Parquet primitive type or None if the given type - * is not primitive. - * - * @param ctype The type to convert - * @return The name of the corresponding Parquet type properties - */ - def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { - case StringType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) - case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) - case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) - case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) - case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) - case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case DateType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) - case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) - case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) - case DecimalType.Fixed(precision, scale) if precision <= 18 => - // TODO: for now, our writer only supports decimals that fit in a Long - Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, - Some(ParquetOriginalType.DECIMAL), - Some(new DecimalMetadata(precision, scale)), - Some(BYTES_FOR_PRECISION(precision)))) - case _ => None + case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true + case _ => false } /** @@ -250,154 +53,18 @@ private[parquet] object ParquetTypesConverter extends Logging { length } - /** - * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into - * the corresponding Parquet `Type`. - * - * The conversion follows the rules below: - *
      - *
    • Primitive types are converted into Parquet's primitive types.
    • - *
    • [[org.apache.spark.sql.types.StructType]]s are converted - * into Parquet's `GroupType` with the corresponding field types.
    • - *
    • [[org.apache.spark.sql.types.ArrayType]]s are converted - * into a 2-level nested group, where the outer group has the inner - * group as sole field. The inner group has name `values` and - * repetition level `REPEATED` and has the element type of - * the array as schema. We use Parquet's `ConversionPatterns` for this - * purpose.
    • - *
    • [[org.apache.spark.sql.types.MapType]]s are converted - * into a nested (2-level) Parquet `GroupType` with two fields: a key - * type and a value type. The nested group has repetition level - * `REPEATED` and name `map`. We use Parquet's `ConversionPatterns` - * for this purpose
    • - *
    - * Parquet's repetition level is generally set according to the following rule: - *
      - *
    • If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or - * `MapType`, then the repetition level is set to `REPEATED`.
    • - *
    • Otherwise, if the attribute whose type is converted is `nullable`, the Parquet - * type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
    • - *
    - * - *@param ctype The type to convert - * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] - * whose type is converted - * @param nullable When true indicates that the attribute is nullable - * @param inArray When true indicates that this is a nested attribute inside an array. - * @return The corresponding Parquet type. - */ - def fromDataType( - ctype: DataType, - name: String, - nullable: Boolean = true, - inArray: Boolean = false, - toThriftSchemaNames: Boolean = false): ParquetType = { - val repetition = - if (inArray) { - Repetition.REPEATED - } else { - if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED - } - val arraySchemaName = if (toThriftSchemaNames) { - name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX - } else { - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - } - val typeInfo = fromPrimitiveDataType(ctype) - typeInfo.map { - case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => - val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) - for (len <- length) { - builder.length(len) - } - for (metadata <- decimalMetadata) { - builder.precision(metadata.getPrecision).scale(metadata.getScale) - } - builder.named(name) - }.getOrElse { - ctype match { - case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames) - } - case ArrayType(elementType, false) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = false, - inArray = true, - toThriftSchemaNames) - ConversionPatterns.listType(repetition, name, parquetElementType) - } - case ArrayType(elementType, true) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = true, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.listType( - repetition, - name, - new ParquetGroupType( - Repetition.REPEATED, - CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, - parquetElementType)) - } - case StructType(structFields) => { - val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, - inArray = false, toThriftSchemaNames) - } - new ParquetGroupType(repetition, name, fields.toSeq) - } - case MapType(keyType, valueType, valueContainsNull) => { - val parquetKeyType = - fromDataType( - keyType, - CatalystConverter.MAP_KEY_SCHEMA_NAME, - nullable = false, - inArray = false, - toThriftSchemaNames) - val parquetValueType = - fromDataType( - valueType, - CatalystConverter.MAP_VALUE_SCHEMA_NAME, - nullable = valueContainsNull, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.mapType( - repetition, - name, - parquetKeyType, - parquetValueType) - } - case _ => sys.error(s"Unsupported datatype $ctype") - } - } - } - - def convertToAttributes(parquetSchema: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - parquetSchema - .asGroupType() - .getFields - .map( - field => - new AttributeReference( - field.getName, - toDataType(field, isBinaryAsString, isInt96AsTimestamp), - field.getRepetition != Repetition.REQUIRED)()) + def convertToAttributes( + parquetSchema: MessageType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { + val converter = new CatalystSchemaConverter( + isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false) + converter.convert(parquetSchema).toAttributes } - def convertFromAttributes(attributes: Seq[Attribute], - toThriftSchemaNames: Boolean = false): MessageType = { - checkSpecialCharacters(attributes) - val fields = attributes.map( - attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable, - toThriftSchemaNames = toThriftSchemaNames)) - new MessageType("root", fields) + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val converter = new CatalystSchemaConverter() + converter.convert(StructType.fromAttributes(attributes)) } def convertFromString(string: String): Seq[Attribute] = { @@ -407,20 +74,8 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - private def checkSpecialCharacters(schema: Seq[Attribute]) = { - // ,;{}()\n\t= and space character are special characters in Parquet schema - schema.map(_.name).foreach { name => - if (name.matches(".*[ ,;{}()\n\t=].*")) { - sys.error( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\n\t=". - |Please use alias to rename it. - """.stripMargin.split("\n").mkString(" ")) - } - } - } - def convertToString(schema: Seq[Attribute]): String = { - checkSpecialCharacters(schema) + schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) StructType.fromAttributes(schema).json } @@ -452,8 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ParquetTypesConverter.convertToString(attributes)) // TODO: add extra data, e.g., table name, date, etc.? - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) + val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) val metaData: FileMetaData = new FileMetaData( parquetSchema, extraMetadata, @@ -489,7 +143,7 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs .globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .filterNot { status => val name = status.getPath.getName (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 946062f6ea64..5ac3e9a44e6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -17,28 +17,31 @@ package org.apache.spark.sql.parquet +import java.net.URI import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.util.Try import com.google.common.base.Objects -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import parquet.filter2.predicate.FilterApi -import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop._ -import parquet.hadoop.metadata.CompressionCodecName -import parquet.hadoop.util.ContextUtil +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD._ -import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends HadoopFsRelationProvider { @@ -48,8 +51,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { schema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { - val partitionSpec = partitionColumns.map(PartitionSpec(_, Seq.empty)) - new ParquetRelation2(paths, schema, partitionSpec, parameters)(sqlContext) + new ParquetRelation2(paths, schema, None, partitionColumns, parameters)(sqlContext) } } @@ -57,51 +59,22 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - private val recordWriter: RecordWriter[Void, Row] = { - val conf = context.getConfiguration + private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { - // When appending new Parquet files to an existing Parquet file directory, to avoid - // overwriting existing data files, we need to find out the max task ID encoded in these data - // file names. - // TODO Make this snippet a utility function for other data source developers - val maxExistingTaskId = { - // Note that `path` may point to a temporary location. Here we retrieve the real - // destination path from the configuration - val outputPath = new Path(conf.get("spark.sql.sources.output.path")) - val fs = outputPath.getFileSystem(conf) - - if (fs.exists(outputPath)) { - // Pattern used to match task ID in part file names, e.g.: - // - // part-r-00001.gz.parquet - // ^~~~~ - val partFilePattern = """part-.-(\d{1,}).*""".r - - fs.listStatus(outputPath).map(_.getPath.getName).map { - case partFilePattern(id) => id.toInt - case name if name.startsWith("_") => 0 - case name if name.startsWith(".") => 0 - case name => sys.error( - s"Trying to write Parquet files to directory $outputPath, " + - s"but found items with illegal name '$name'.") - }.reduceOption(_ max _).getOrElse(0) - } else { - 0 - } - } - - new ParquetOutputFormat[Row]() { + new ParquetOutputFormat[InternalRow]() { // Here we override `getDefaultWorkFile` for two reasons: // - // 1. To allow appending. We need to generate output file name based on the max available - // task ID computed above. + // 1. To allow appending. We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). // // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val split = context.getTaskAttemptID.getTaskID.getId + maxExistingTaskId + 1 - new Path(path, f"part-r-$split%05d$extension") + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val split = context.getTaskAttemptID.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } } } @@ -109,7 +82,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def write(row: Row): Unit = recordWriter.write(null, row) + override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) override def close(): Unit = recordWriter.close(context) } @@ -117,15 +90,34 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext private[sql] class ParquetRelation2( override val paths: Array[String], private val maybeDataSchema: Option[StructType], + // This is for metastore conversion. private val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec) with Logging { + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + // Should we merge schemas from all Parquet part-files? private val shouldMergeSchemas = - parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean + parameters + .get(ParquetRelation2.MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) private val maybeMetastoreSchema = parameters .get(ParquetRelation2.METASTORE_SCHEMA) @@ -137,7 +129,7 @@ private[sql] class ParquetRelation2( meta } - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case that: ParquetRelation2 => val schemaEquality = if (shouldMergeSchemas) { this.shouldMergeSchemas == that.shouldMergeSchemas @@ -160,7 +152,7 @@ private[sql] class ParquetRelation2( Boolean.box(shouldMergeSchemas), paths.toSet, maybeDataSchema, - maybePartitionSpec) + partitionColumns) } else { Objects.hashCode( Boolean.box(shouldMergeSchemas), @@ -168,15 +160,15 @@ private[sql] class ParquetRelation2( dataSchema, schema, maybeDataSchema, - maybePartitionSpec) + partitionColumns) } } - override def dataSchema: StructType = metadataCache.dataSchema + override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) override private[sql] def refresh(): Unit = { - metadataCache.refresh() super.refresh() + metadataCache.refresh() } // Parquet data source always uses Catalyst internal representations. @@ -184,23 +176,33 @@ private[sql] class ParquetRelation2( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - override def userDefinedPartitionColumns: Option[StructType] = - maybePartitionSpec.map(_.partitionColumns) - override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) val committerClass = conf.getClass( - "spark.sql.parquet.output.committer.class", + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo("Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + conf.setClass( - "mapred.output.committer.class", + SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, classOf[ParquetOutputCommitter]) + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + // TODO There's no need to use two kinds of WriteSupport // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and // complex types. @@ -234,95 +236,82 @@ private[sql] class ParquetRelation2( override def buildScan( requiredColumns: Array[String], filters: Array[Filter], - inputPaths: Array[String]): RDD[Row] = { - - val job = new Job(SparkHadoopUtil.get.conf) - val conf = ContextUtil.getConfiguration(job) - - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - - if (inputPaths.nonEmpty) { - FileInputFormat.setInputPaths(job, inputPaths.map(new Path(_)): _*) - } - - // Try to push down filters when filter push-down is enabled. - if (sqlContext.conf.parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(dataSchema, _)) - .reduceOption(FilterApi.and) - .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) - } - - conf.set(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { - val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) - ParquetTypesConverter.convertToString(requestedSchema.toAttributes) - }) - - conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(dataSchema.toAttributes)) + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) + val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown + // Create the function to set variable Parquet confs at both driver and executor side. + val initLocalJobFuncOpt = + ParquetRelation2.initializeLocalJobFunc( + requiredColumns, + filters, + dataSchema, + useMetadataCache, + parquetFilterPushDown) _ + // Create the function to set input paths at the driver side. + val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ + + val footers = inputFiles.map(f => metadataCache.footers(f.getPath)) + + Utils.withDummyCallSite(sqlContext.sparkContext) { + // TODO Stop using `FilteringParquetRowInputFormat` and overriding `getPartition`. + // After upgrading to Parquet 1.6.0, we should be able to stop caching `FileStatus` objects + // and footers. Especially when a global arbitrative schema (either from metastore or data + // source DDL) is available. + new SqlNewHadoopRDD( + sc = sqlContext.sparkContext, + broadcastedConf = broadcastedConf, + initDriverSideJobFuncOpt = Some(setInputPaths), + initLocalJobFuncOpt = Some(initLocalJobFuncOpt), + inputFormatClass = classOf[FilteringParquetRowInputFormat], + keyClass = classOf[Void], + valueClass = classOf[InternalRow]) { + + val cacheMetadata = useMetadataCache + + @transient val cachedStatuses = inputFiles.map { f => + // In order to encode the authority of a Path containing special characters such as '/' + // (which does happen in some S3N credentials), we need to use the string returned by the + // URI of the path to create a new Path. + val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) + new FileStatus( + f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) + }.toSeq + + @transient val cachedFooters = footers.map { f => + // In order to encode the authority of a Path containing special characters such as /, + // we need to use the string returned by the URI of the path to create a new Path. + new Footer(escapePathUserInfo(f.getFile), f.getParquetMetadata) + }.toSeq + + private def escapePathUserInfo(path: Path): Path = { + val uri = path.toUri + new Path(new URI( + uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, + uri.getQuery, uri.getFragment)) + } - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean - conf.set(SQLConf.PARQUET_CACHE_METADATA, useMetadataCache.toString) - - val inputFileStatuses = - metadataCache.dataStatuses.filter(f => inputPaths.contains(f.getPath.toString)) - - val footers = inputFileStatuses.map(metadataCache.footers) - - // TODO Stop using `FilteringParquetRowInputFormat` and overriding `getPartition`. - // After upgrading to Parquet 1.6.0, we should be able to stop caching `FileStatus` objects and - // footers. Especially when a global arbitrative schema (either from metastore or data source - // DDL) is available. - new NewHadoopRDD( - sqlContext.sparkContext, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row], - conf) { - - val cacheMetadata = useMetadataCache - - @transient val cachedStatuses = inputFileStatuses.map { f => - // In order to encode the authority of a Path containing special characters such as /, - // we need to use the string returned by the URI of the path to create a new Path. - val pathWithAuthority = new Path(f.getPath.toUri.toString) - - new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithAuthority) - }.toSeq - - @transient val cachedFooters = footers.map { f => - // In order to encode the authority of a Path containing special characters such as /, - // we need to use the string returned by the URI of the path to create a new Path. - new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) - }.toSeq - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatuses - - override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters + // Overridden so we can inject our own cached files statuses. + override def getPartitions: Array[SparkPartition] = { + val inputFormat = if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatuses + override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters + } + } else { + new FilteringParquetRowInputFormat } - } else { - new FilteringParquetRowInputFormat - } - val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext) + val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val rawSplits = inputFormat.getSplits(jobContext) - Array.tabulate[SparkPartition](rawSplits.size) { i => - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + Array.tabulate[SparkPartition](rawSplits.size) { i => + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } } - } - }.values + }.values.map(_.asInstanceOf[Row]) + } } private class MetadataCache { @@ -333,14 +322,14 @@ private[sql] class ParquetRelation2( private var commonMetadataStatuses: Array[FileStatus] = _ // Parquet footer cache. - var footers: Map[FileStatus, Footer] = _ + var footers: Map[Path, Footer] = _ // `FileStatus` objects of all data files (Parquet part-files). var dataStatuses: Array[FileStatus] = _ // Schema of the actual Parquet files, without partition columns discovered from partition // directory paths. - var dataSchema: StructType = _ + var dataSchema: StructType = null // Schema of the whole table, including partition columns. var schema: StructType = _ @@ -349,49 +338,49 @@ private[sql] class ParquetRelation2( * Refreshes `FileStatus`es, footers, partition spec, and table schema. */ def refresh(): Unit = { - // Support either reading a collection of raw Parquet part-files, or a collection of folders - // containing Parquet files (e.g. partitioned Parquet table). - val baseStatuses = paths.distinct.flatMap { p => - val path = new Path(p) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) - Try(fs.getFileStatus(qualified)).toOption - } - assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. - val leaves = baseStatuses.flatMap { f => - val fs = FileSystem.get(f.getPath.toUri, SparkHadoopUtil.get.conf) - SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => - isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - } - } + val leaves = cachedLeafStatuses().filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + }.toArray dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) commonMetadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => - val parquetMetadata = ParquetFileReader.readFooter( - SparkHadoopUtil.get.conf, f, ParquetMetadataConverter.NO_FILTER) - f -> new Footer(f.getPath, parquetMetadata) - }.seq.toMap + footers = { + val conf = SparkHadoopUtil.get.conf + val taskSideMetaData = conf.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) + val rawFooters = if (shouldMergeSchemas) { + ParquetFileReader.readAllFootersInParallel( + conf, seqAsJavaList(leaves), taskSideMetaData) + } else { + ParquetFileReader.readAllFootersInParallelUsingSummaryFiles( + conf, seqAsJavaList(leaves), taskSideMetaData) + } - dataSchema = { - val dataSchema0 = - maybeDataSchema + rawFooters.map(footer => footer.getFile -> footer).toMap + } + + // If we already get the schema, don't need to re-compute it since the schema merging is + // time-consuming. + if (dataSchema == null) { + dataSchema = { + val dataSchema0 = maybeDataSchema .orElse(readSchema()) .orElse(maybeMetastoreSchema) - .getOrElse(sys.error("Failed to get the schema.")) - - // If this Parquet relation is converted from a Hive Metastore table, must reconcile case - // case insensitivity issue and possible schema mismatch (probably caused by schema - // evolution). - maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) - .getOrElse(dataSchema0) + .getOrElse(throw new AnalysisException( + s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + + paths.mkString("\n\t"))) + + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // case insensitivity issue and possible schema mismatch (probably caused by schema + // evolution). + maybeMetastoreSchema + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) + .getOrElse(dataSchema0) + } } } @@ -444,7 +433,7 @@ private[sql] class ParquetRelation2( "No schema defined, " + s"and no Parquet data file or summary file found under ${paths.mkString(", ")}.") - ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) + ParquetRelation2.readSchema(filesToTouch.map(f => footers.apply(f.getPath)), sqlContext) } } } @@ -457,6 +446,49 @@ private[sql] object ParquetRelation2 extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + /** This closure sets various Parquet configurations at both driver side and executor side. */ + private[parquet] def initializeLocalJobFunc( + requiredColumns: Array[String], + filters: Array[Filter], + dataSchema: StructType, + useMetadataCache: Boolean, + parquetFilterPushDown: Boolean)(job: Job): Unit = { + val conf = job.getConfiguration + conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName()) + + // Try to push down filters when filter push-down is enabled. + if (parquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(dataSchema, _)) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) + } + + conf.set(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { + val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) + ParquetTypesConverter.convertToString(requestedSchema.toAttributes) + }) + + conf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + ParquetTypesConverter.convertToString(dataSchema.toAttributes)) + + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) + } + + /** This closure sets input paths at the driver side. */ + private[parquet] def initializeDriverSideJobFunc( + inputFiles: Array[FileStatus])(job: Job): Unit = { + // We side the input paths at the driver side. + if (inputFiles.nonEmpty) { + FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) + } + } + private[parquet] def readSchema( footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { footers.map { footer => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala deleted file mode 100644 index 70bcca7526aa..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet.timestamp - -import java.nio.{ByteBuffer, ByteOrder} - -import parquet.Preconditions -import parquet.io.api.{Binary, RecordConsumer} - -private[parquet] class NanoTime extends Serializable { - private var julianDay = 0 - private var timeOfDayNanos = 0L - - def set(julianDay: Int, timeOfDayNanos: Long): this.type = { - this.julianDay = julianDay - this.timeOfDayNanos = timeOfDayNanos - this - } - - def getJulianDay: Int = julianDay - - def getTimeOfDayNanos: Long = timeOfDayNanos - - def toBinary: Binary = { - val buf = ByteBuffer.allocate(12) - buf.order(ByteOrder.LITTLE_ENDIAN) - buf.putLong(timeOfDayNanos) - buf.putInt(julianDay) - buf.flip() - Binary.fromByteBuffer(buf) - } - - def writeValue(recordConsumer: RecordConsumer): Unit = { - recordConsumer.addBinary(toBinary) - } - - override def toString: String = - "NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}" -} - -private[sql] object NanoTime { - def fromBinary(bytes: Binary): NanoTime = { - Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes") - val buf = bytes.toByteBuffer - buf.order(ByteOrder.LITTLE_ENDIAN) - val timeOfDayNanos = buf.getLong - val julianDay = buf.getInt - new NanoTime().set(julianDay, timeOfDayNanos) - } - - def apply(julianDay: Int, timeOfDayNanos: Long): NanoTime = { - new NanoTime().set(julianDay, timeOfDayNanos) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index e6324b20b306..66f7ba90140b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -17,56 +17,55 @@ package org.apache.spark.sql.sources -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.{UnionRDD, RDD} -import org.apache.spark.sql.Row +import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.{StructType, UTF8String, StringType} -import org.apache.spark.sql._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.{SaveMode, Strategy, execution, sources} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.unsafe.types.UTF8String /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan)) => pruneFilterProjectRaw( l, - projectList, + projects, filters, - (a, f) => t.buildScan(a, f)) :: Nil + (a, f) => toCatalystRDD(l, a, t.buildScan(a, f))) :: Nil - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => pruneFilterProject( l, - projectList, + projects, filters, - (a, f) => t.buildScan(a, f)) :: Nil + (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan)) => pruneFilterProject( l, - projectList, + projects, filters, - (a, _) => t.buildScan(a)) :: Nil + (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil - // Scanning partitioned FSBasedRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) + // Scanning partitioned HadoopFsRelation + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) if t.partitionSpec.partitionColumns.nonEmpty => val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray logInfo { val total = t.partitionSpec.partitions.length val selected = selectedPartitions.length - val percentPruned = (1 - total.toDouble / selected.toDouble) * 100 + val percentPruned = (1 - selected.toDouble / total.toDouble) * 100 s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." } @@ -81,40 +80,36 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { buildPartitionedTableScan( l, - projectList, + projects, pushedFilters, t.partitionSpec.partitionColumns, selectedPartitions) :: Nil - // Scanning non-partitioned FSBasedRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) => - val inputPaths = t.paths.map(new Path(_)).flatMap { path => - val fs = path.getFileSystem(t.sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.listLeafStatuses(fs, qualifiedPath).map(_.getPath).filterNot { path => - val name = path.getName - name.startsWith("_") || name.startsWith(".") - }.map(fs.makeQualified(_).toString) - } - + // Scanning non-partitioned HadoopFsRelation + case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) => + // See buildPartitionedTableScan for the reason that we need to create a shard + // broadcast HadoopConf. + val sharedHadoopConf = SparkHadoopUtil.get.conf + val confBroadcast = + t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) pruneFilterProject( l, - projectList, + projects, filters, - (a, f) => t.buildScan(a, f, inputPaths)) :: Nil + (a, f) => + toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil case l @ LogicalRelation(t: TableScan) => - createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil + execution.PhysicalRDD(l.output, toCatalystRDD(l, t.buildScan())) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation), part, query, overwrite, false) if part.isEmpty => + l @ LogicalRelation(t: HadoopFsRelation), part, query, overwrite, false) => val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - execution.ExecutedCommand( - InsertIntoHadoopFsRelation(t, query, Array.empty[String], mode)) :: Nil + execution.ExecutedCommand(InsertIntoHadoopFsRelation(t, query, mode)) :: Nil case _ => Nil } @@ -125,21 +120,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters: Seq[Expression], partitionColumns: StructType, partitions: Array[Partition]) = { - val output = projections.map(_.toAttribute) val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] + // Because we are creating one RDD per partition, we need to have a shared HadoopConf. + // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. + val sharedHadoopConf = SparkHadoopUtil.get.conf + val confBroadcast = + relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) + // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => - // Paths to all data files within this partition - val dataFilePaths = { - val dirPath = new Path(dir) - val fs = dirPath.getFileSystem(SparkHadoopUtil.get.conf) - fs.listStatus(dirPath).map(_.getPath).filterNot { path => - val name = path.getName - name.startsWith("_") || name.startsWith(".") - }.map(fs.makeQualified(_).toString) - } - // The table scan operator (PhysicalRDD) which retrieves required columns from data files. // Notice that the schema of data files, represented by `relation.dataSchema`, may contain // some partition column(s). @@ -148,22 +138,23 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { logicalRelation, projections, filters, - (requiredColumns, filters) => { + (columns: Seq[Attribute], filters) => { val partitionColNames = partitionColumns.fieldNames // Don't scan any partition columns to save I/O. Here we are being optimistic and // assuming partition columns data stored in data files are always consistent with those // partition values encoded in partition directory paths. - val nonPartitionColumns = requiredColumns.filterNot(partitionColNames.contains) - val dataRows = relation.buildScan(nonPartitionColumns, filters, dataFilePaths) + val needed = columns.filterNot(a => partitionColNames.contains(a.name)) + val dataRows = + relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) // Merges data values with partition values. mergeWithPartitionValues( relation.schema, - requiredColumns, + columns.map(_.name).toArray, partitionColNames, partitionValues, - dataRows) + toCatalystRDD(logicalRelation, needed, dataRows)) }) scan.execute() @@ -176,15 +167,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) } - createPhysicalRDD(logicalRelation.relation, output, unionedRows) + execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows) } private def mergeWithPartitionValues( schema: StructType, requiredColumns: Array[String], partitionColumns: Array[String], - partitionValues: Row, - dataRows: RDD[Row]): RDD[Row] = { + partitionValues: InternalRow, + dataRows: RDD[InternalRow]): RDD[InternalRow] = { val nonPartitionColumns = requiredColumns.filterNot(partitionColumns.contains) // If output columns contain any partition column(s), we need to merge scanned data @@ -195,19 +186,22 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val i = partitionColumns.indexOf(name) if (i != -1) { // If yes, gets column value from partition values. - (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => { + (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { mutableRow(ordinal) = partitionValues(i) } } else { // Otherwise, inherits the value from scanned data. val i = nonPartitionColumns.indexOf(name) - (mutableRow: MutableRow, dataRow: expressions.Row, ordinal: Int) => { + (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { mutableRow(ordinal) = dataRow(i) } } } - dataRows.mapPartitions { iterator => + // Since we know for sure that this closure is serializable, we can avoid the overhead + // of cleaning a closure for each RDD by creating our own MapPartitionsRDD. Functionally + // this is equivalent to calling `dataRows.mapPartitions(mapPartitionsFunc)` (SPARK-7718). + val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[InternalRow]) => { val dataTypes = requiredColumns.map(schema(_).dataType) val mutableRow = new SpecificMutableRow(dataTypes) iterator.map { dataRow => @@ -216,9 +210,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { mergers(i)(mutableRow, dataRow, i) i += 1 } - mutableRow.asInstanceOf[expressions.Row] + mutableRow.asInstanceOf[InternalRow] } } + + // This is an internal RDD whose call site the user should not be concerned with + // Since we create many of these (one per partition), the time spent on computing + // the call site may add up. + Utils.withDummyCallSite(dataRows.sparkContext) { + new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) + } + } else { dataRows } @@ -254,26 +256,26 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Based on Public API. protected def pruneFilterProject( relation: LogicalRelation, - projectList: Seq[NamedExpression], + projects: Seq[NamedExpression], filterPredicates: Seq[Expression], - scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow]) = { pruneFilterProjectRaw( relation, - projectList, + projects, filterPredicates, (requestedColumns, pushedFilters) => { - scanBuilder(requestedColumns.map(_.name).toArray, selectFilters(pushedFilters).toArray) + scanBuilder(requestedColumns, selectFilters(pushedFilters).toArray) }) } // Based on Catalyst expressions. protected def pruneFilterProjectRaw( relation: LogicalRelation, - projectList: Seq[NamedExpression], + projects: Seq[NamedExpression], filterPredicates: Seq[Expression], - scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[Row]) = { + scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[InternalRow]) = { - val projectSet = AttributeSet(projectList.flatMap(_.references)) + val projectSet = AttributeSet(projects.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = filterPredicates.reduceLeftOption(expressions.And) @@ -281,38 +283,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. }} - if (projectList.map(_.toAttribute) == projectList && - projectSet.size == projectList.size && + if (projects.map(_.toAttribute) == projects && + projectSet.size == projects.size && filterSet.subsetOf(projectSet)) { // When it is possible to just use column pruning to get the right projection and // when the columns of this projection are enough to evaluate all filter conditions, // just do a scan followed by a filter, with no extra project. val requestedColumns = - projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. + projects.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), - scanBuilder(requestedColumns, pushedFilters)) + val scan = execution.PhysicalRDD(projects.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = createPhysicalRDD(relation.relation, requestedColumns, + val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) - execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } - private[this] def createPhysicalRDD( - relation: BaseRelation, + /** + * Convert RDD of Row into RDD of InternalRow with objects in catalyst types + */ + private[this] def toCatalystRDD( + relation: LogicalRelation, output: Seq[Attribute], - rdd: RDD[Row]): SparkPlan = { - val converted = if (relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + rdd: RDD[Row]): RDD[InternalRow] = { + if (relation.relation.needConversion) { + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { - rdd + rdd.map(_.asInstanceOf[InternalRow]) } - execution.PhysicalRDD(output, converted) + } + + /** + * Convert RDD of Row into RDD of InternalRow with objects in catalyst types + */ + private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { + toCatalystRDD(relation, relation.output, rdd) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index d1f0cdab55f6..8b2a45d8e970 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -17,23 +17,26 @@ package org.apache.spark.sql.sources -import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.lang.{Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import scala.collection.mutable.ArrayBuffer import scala.util.Try -import com.google.common.cache.{CacheBuilder, Cache} -import org.apache.hadoop.fs.{FileStatus, Path} - -import org.apache.spark.sql.Row +import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.Shell +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ -private[sql] case class Partition(values: Row, path: String) +private[sql] case class Partition(values: InternalRow, path: String) private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] object PartitionSpec { + val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) +} + private[sql] object PartitioningUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't // depend on Hive. @@ -68,21 +71,39 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartitions( paths: Seq[Path], - defaultPartitionName: String): PartitionSpec = { - val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName))) - val fields = { - val (PartitionValues(columnNames, literals)) = partitionValues.head - columnNames.zip(literals).map { case (name, Literal(_, dataType)) => - StructField(name, dataType, nullable = true) - } + defaultPartitionName: String, + typeInference: Boolean): PartitionSpec = { + // First, we need to parse every partition's path and see if we can find partition values. + val pathsWithPartitionValues = paths.flatMap { path => + parsePartition(path, defaultPartitionName, typeInference).map(path -> _) } - val partitions = partitionValues.zip(paths).map { - case (PartitionValues(_, literals), path) => - Partition(Row(literals.map(_.value): _*), path.toString) - } + if (pathsWithPartitionValues.isEmpty) { + // This dataset is not partitioned. + PartitionSpec.emptySpec + } else { + // This dataset is partitioned. We need to check whether all partitions have the same + // partition columns and resolve potential type conflicts. + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) - PartitionSpec(StructType(fields), partitions) + // Creates the StructType which represents the partition columns. + val fields = { + val PartitionValues(columnNames, literals) = resolvedPartitionValues.head + columnNames.zip(literals).map { case (name, Literal(_, dataType)) => + // We always assume partition columns are nullable since we've no idea whether null values + // will be appended in the future. + StructField(name, dataType, nullable = true) + } + } + + // Finally, we create `Partition`s based on paths and resolved partition values. + val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { + case (PartitionValues(_, literals), (path, _)) => + Partition(InternalRow.fromSeq(literals.map(_.value)), path.toString) + } + + PartitionSpec(StructType(fields), partitions) + } } /** @@ -103,26 +124,38 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartition( path: Path, - defaultPartitionName: String): PartitionValues = { + defaultPartitionName: String, + typeInference: Boolean): Option[PartitionValues] = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null var chopped = path while (!finished) { - val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + // Sometimes (e.g., when speculative task is enabled), temporary directories may be left + // uncleaned. Here we simply ignore them. + if (chopped.getName.toLowerCase == "_temporary") { + return None + } + + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) maybeColumn.foreach(columns += _) chopped = chopped.getParent finished = maybeColumn.isEmpty || chopped.getParent == null } - val (columnNames, values) = columns.reverse.unzip - PartitionValues(columnNames, values) + if (columns.isEmpty) { + None + } else { + val (columnNames, values) = columns.reverse.unzip + Some(PartitionValues(columnNames, values)) + } } private def parsePartitionColumn( columnSpec: String, - defaultPartitionName: String): Option[(String, Literal)] = { + defaultPartitionName: String, + typeInference: Boolean): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { None @@ -133,7 +166,7 @@ private[sql] object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName, typeInference) Some(columnName -> literal) } } @@ -144,50 +177,93 @@ private[sql] object PartitioningUtils { * {{{ * NullType -> * IntegerType -> LongType -> - * FloatType -> DoubleType -> DecimalType.Unlimited -> + * DoubleType -> DecimalType.Unlimited -> * StringType * }}} */ - private[sql] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { - // Column names of all partitions must match - val distinctPartitionsColNames = values.map(_.columnNames).distinct - assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n", "") - s"Conflicting partition column names detected:\n$list" - }) + private[sql] def resolvePartitions( + pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = { + if (pathsWithPartitionValues.isEmpty) { + Seq.empty + } else { + val distinctPartColNames = pathsWithPartitionValues.map(_._2.columnNames).distinct + assert( + distinctPartColNames.size == 1, + listConflictingPartitionColumns(pathsWithPartitionValues)) + + // Resolves possible type conflicts for each column + val values = pathsWithPartitionValues.map(_._2) + val columnCount = values.head.columnNames.size + val resolvedValues = (0 until columnCount).map { i => + resolveTypeConflicts(values.map(_.literals(i))) + } - // Resolves possible type conflicts for each column - val columnCount = values.head.columnNames.size - val resolvedValues = (0 until columnCount).map { i => - resolveTypeConflicts(values.map(_.literals(i))) + // Fills resolved literals back to each partition + values.zipWithIndex.map { case (d, index) => + d.copy(literals = resolvedValues.map(_(index))) + } } + } + + private[sql] def listConflictingPartitionColumns( + pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = { + val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct + + def groupByKey[K, V](seq: Seq[(K, V)]): Map[K, Iterable[V]] = + seq.groupBy { case (key, _) => key }.mapValues(_.map { case (_, value) => value }) + + val partColNamesToPaths = groupByKey(pathWithPartitionValues.map { + case (path, partValues) => partValues.columnNames -> path + }) - // Fills resolved literals back to each partition - values.zipWithIndex.map { case (d, index) => - d.copy(literals = resolvedValues.map(_(index))) + val distinctPartColLists = distinctPartColNames.map(_.mkString(", ")).zipWithIndex.map { + case (names, index) => + s"Partition column name list #$index: $names" } + + // Lists out those non-leaf partition directories that also contain files + val suspiciousPaths = distinctPartColNames.sortBy(_.length).flatMap(partColNamesToPaths) + + s"Conflicting partition column names detected:\n" + + distinctPartColLists.mkString("\n\t", "\n\t", "\n\n") + + "For partitioned table directories, data files should only live in leaf directories.\n" + + "And directories at the same level should have the same partition column name.\n" + + "Please check the following directories for unexpected files or " + + "inconsistent partition column names:\n" + + suspiciousPaths.map("\t" + _).mkString("\n", "\n", "") } /** - * Converts a string to a `Literal` with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * Converts a string to a [[Literal]] with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and * [[StringType]]. */ private[sql] def inferPartitionColumnValue( raw: String, - defaultPartitionName: String): Literal = { - // First tries integral types - Try(Literal.create(Integer.parseInt(raw), IntegerType)) - .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) - // Then falls back to fractional types - .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) - .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) - // Then falls back to string - .getOrElse { - if (raw == defaultPartitionName) Literal.create(null, NullType) - else Literal.create(raw, StringType) + defaultPartitionName: String, + typeInference: Boolean): Literal = { + if (typeInference) { + // First tries integral types + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) + } + } + } else { + if (raw == defaultPartitionName) { + Literal.create(null, NullType) + } else { + Literal.create(unescapePathName(raw), StringType) } + } } private val upCastingOrder: Seq[DataType] = @@ -208,4 +284,77 @@ private[sql] object PartitioningUtils { Literal.create(Cast(l, desiredType).eval(), desiredType) } } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). + ////////////////////////////////////////////////////////////////////////////////////////////////// + + val charToEscape = { + val bitSet = new java.util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. + */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c) + } + + def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02x") + } else { + builder.append(c) + } + } + + builder.toString() + } + + def unescapePathName(path: String): String = { + val sb = new StringBuilder + var i = 0 + + while (i < path.length) { + val c = path.charAt(i) + if (c == '%' && i + 2 < path.length) { + val code: Int = try { + Integer.valueOf(path.substring(i + 1, i + 3), 16) + } catch { case e: Exception => + -1: Integer + } + if (code >= 0) { + sb.append(code.asInstanceOf[Char]) + i += 3 + } else { + sb.append(c) + i += 1 + } + } else { + sb.append(c) + i += 1 + } + } + + sb.toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala new file mode 100644 index 000000000000..2bdc34102125 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.spark.broadcast.Broadcast + +import org.apache.spark.{Partition => SparkPartition, _} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.DataReadMethod +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.rdd.{RDD, HadoopRDD} +import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{SerializableConfiguration, Utils} + +import scala.reflect.ClassTag + +private[spark] class SqlNewHadoopPartition( + rddId: Int, + val index: Int, + @transient rawSplit: InputSplit with Writable) + extends SparkPartition { + + val serializableHadoopSplit = new SerializableWritable(rawSplit) + + override def hashCode(): Int = 41 * (41 + rddId) + index +} + +/** + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). + * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. + * 1. A shared broadcast Hadoop Configuration. + * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side + * to the shared Hadoop Configuration. + * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side + * and the executor side to the shared Hadoop Configuration. + * + * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be + * folded into core. + */ +private[sql] class SqlNewHadoopRDD[K, V]( + @transient sc : SparkContext, + broadcastedConf: Broadcast[SerializableConfiguration], + @transient initDriverSideJobFuncOpt: Option[Job => Unit], + initLocalJobFuncOpt: Option[Job => Unit], + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V]) + extends RDD[(K, V)](sc, Nil) + with SparkHadoopMapReduceUtil + with Logging { + + protected def getJob(): Job = { + val conf: Configuration = broadcastedConf.value.value + // "new Job" will make a copy of the conf. Then, it is + // safe to mutate conf properties with initLocalJobFuncOpt + // and initDriverSideJobFuncOpt. + val newJob = new Job(conf) + initLocalJobFuncOpt.map(f => f(newJob)) + newJob + } + + def getConf(isDriverSide: Boolean): Configuration = { + val job = getJob() + if (isDriverSide) { + initDriverSideJobFuncOpt.map(f => f(job)) + } + job.getConfiguration + } + + private val jobTrackerId: String = { + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + formatter.format(new Date()) + } + + @transient protected val jobId = new JobID(jobTrackerId, id) + + override def getPartitions: Array[SparkPartition] = { + val conf = getConf(isDriverSide = true) + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[SparkPartition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } + + override def compute( + theSplit: SparkPartition, + context: TaskContext): InterruptibleIterator[(K, V)] = { + val iter = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[SqlNewHadoopPartition] + logInfo("Input split: " + split.serializableHadoopSplit) + val conf = getConf(isDriverSide = false) + + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + } + inputMetrics.setBytesReadCallback(bytesReadCallback) + + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(context => close()) + var havePair = false + var finished = false + var recordsSinceMetricsUpdate = 0 + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !reader.nextKeyValue + havePair = !finished + } + !finished + } + + override def next(): (K, V) = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + if (!finished) { + inputMetrics.incRecordsRead(1) + } + (reader.getCurrentKey, reader.getCurrentValue) + } + + private def close() { + try { + reader.close() + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } + } catch { + case e: Exception => { + if (!Utils.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } + } + } + } + new InterruptibleIterator(context, iter) + } + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + + override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { + val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value + val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e : Exception => + logDebug("Failed to use InputSplit#getLocationInfo.", e) + None + } + case None => None + } + locs.getOrElse(split.getLocations.filter(_ != "localhost")) + } + + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } +} + +private[spark] object SqlNewHadoopRDD { + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. + */ + private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[SparkPartition] = firstParent[T].partitions + + override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { + val partition = split.asInstanceOf[SqlNewHadoopPartition] + val inputSplit = partition.serializableHadoopSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index a09bb08de736..7214eb0b4169 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -17,25 +17,26 @@ package org.apache.spark.sql.sources -import java.util.Date +import java.util.{Date, UUID} import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} -import org.apache.hadoop.util.Shell -import parquet.hadoop.util.ContextUtil +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.util.SerializableConfiguration private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -47,8 +48,7 @@ private[sql] case class InsertIntoDataSource( val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. - val df = sqlContext.createDataFrame( - data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false) + val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) // Invalidate the cache. @@ -58,10 +58,31 @@ private[sql] case class InsertIntoDataSource( } } +/** + * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. + * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a + * single write job, and owns a UUID that identifies this job. Each concrete implementation of + * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for + * each task output file. This UUID is passed to executor side via a property named + * `spark.sql.sources.writeJobUUID`. + * + * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] + * are used to write to normal tables and tables with dynamic partitions. + * + * Basic work flow of this command is: + * + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ private[sql] case class InsertIntoHadoopFsRelation( @transient relation: HadoopFsRelation, @transient query: LogicalPlan, - partitionColumns: Array[String], mode: SaveMode) extends RunnableCommand { @@ -75,7 +96,8 @@ private[sql] case class InsertIntoHadoopFsRelation( val fs = outputPath.getFileSystem(hadoopConf) val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val doInsertion = (mode, fs.exists(qualifiedOutputPath)) match { + val pathExists = fs.exists(qualifiedOutputPath) + val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => sys.error(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => @@ -86,28 +108,40 @@ private[sql] case class InsertIntoHadoopFsRelation( case (SaveMode.Ignore, exists) => !exists } + // If we are appending data to an existing dir. + val isAppend = (pathExists) && (mode == SaveMode.Append) if (doInsertion) { val job = new Job(hadoopConf) job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[Row]) + job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - val df = sqlContext.createDataFrame( - DataFrame(sqlContext, query).queryExecution.toRdd, - relation.schema, - needsConversion = false) + // We create a DataFrame by applying the schema of relation to the data to make sure. + // We are writing data based on the expected schema, + val df = { + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). We + // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can + // safely apply the schema of r.schema to the data. + val project = Project( + relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) + + sqlContext.internalCreateDataFrame( + DataFrame(sqlContext, project).queryExecution.toRdd, relation.schema) + } + val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { - insert(new DefaultWriterContainer(relation, job), df) + insert(new DefaultWriterContainer(relation, job, isAppend), df) } else { val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME) + relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend) insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) } } - Seq.empty[Row] + Seq.empty[InternalRow] } private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = { @@ -115,9 +149,12 @@ private[sql] case class InsertIntoHadoopFsRelation( val needsConversion = relation.needConversion val dataSchema = relation.dataSchema + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -126,22 +163,21 @@ private[sql] case class InsertIntoHadoopFsRelation( throw new SparkException("Job aborted.", cause) } - def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) - + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + // If anything below fails, we should abort the task. try { - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - while (iterator.hasNext) { - val row = converter(iterator.next()).asInstanceOf[Row] - writerContainer.outputWriterForRow(row).write(row) - } + writerContainer.executorSideSetup(taskContext) + + val converter = if (needsConversion) { + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] } else { - while (iterator.hasNext) { - val row = iterator.next() - writerContainer.outputWriterForRow(row).write(row) - } + r: InternalRow => r.asInstanceOf[Row] + } + while (iterator.hasNext) { + val row = converter(iterator.next()) + writerContainer.outputWriterForRow(row).write(row) } + writerContainer.commitTask() } catch { case cause: Throwable => logError("Aborting task.", cause) @@ -179,9 +215,12 @@ private[sql] case class InsertIntoHadoopFsRelation( val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) val codegenEnabled = df.sqlContext.conf.codegenEnabled + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -190,31 +229,37 @@ private[sql] case class InsertIntoHadoopFsRelation( throw new SparkException("Job aborted.", cause) } - def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + // If anything below fails, we should abort the task. + try { + writerContainer.executorSideSetup(taskContext) - val partitionProj = newProjection(codegenEnabled, partitionOutput, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) + val partitionProj = newProjection(codegenEnabled, partitionOutput, output) + val dataProj = newProjection(codegenEnabled, dataOutput, output) - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = partitionProj(row) - val dataPart = dataProj(row) - val convertedDataPart = converter(dataPart).asInstanceOf[Row] - writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) + val dataConverter: InternalRow => Row = if (needsConversion) { + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } else { + r: InternalRow => r.asInstanceOf[Row] } - } else { + val partitionSchema = StructType.fromAttributes(partitionOutput) + val partConverter: InternalRow => Row = + CatalystTypeConverters.createToScalaConverter(partitionSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { val row = iterator.next() - val partitionPart = partitionProj(row) - val dataPart = dataProj(row) + val partitionPart = partConverter(partitionProj(row)) + val dataPart = dataConverter(dataProj(row)) writerContainer.outputWriterForRow(partitionPart).write(dataPart) } - } - writerContainer.commitTask() + writerContainer.commitTask() + } catch { case cause: Throwable => + logError("Aborting task.", cause) + writerContainer.abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } } } @@ -235,12 +280,20 @@ private[sql] case class InsertIntoHadoopFsRelation( private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, - @transient job: Job) + @transient job: Job, + isAppend: Boolean) extends SparkHadoopMapReduceUtil with Logging with Serializable { - protected val serializableConf = new SerializableWritable(ContextUtil.getConfiguration(job)) + protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + + // This UUID is used to avoid output file name collision between different appending write jobs. + // These jobs may belong to different SparkContext instances. Concrete data source implementations + // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). + // The reason why this ID is used to identify a job rather than a single task output file is + // that, speculative tasks must generate the same output file name as the original task. + private val uniqueWriteJobId = UUID.randomUUID() // This is only used on driver side. @transient private val jobContext: JobContext = job @@ -268,8 +321,22 @@ private[sql] abstract class BaseWriterContainer( def driverSideSetup(): Unit = { setupIDs(0, 0, 0) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + + // This UUID is sent to executor side together with the serialized `Configuration` object within + // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate + // unique task output files. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor + // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, + // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. + // + // Also, the `prepareJobForWrite` call must happen before initializing output format and output + // committer, since their initialization involve the job configuration, which can be potentially + // decorated in `prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupJob(jobContext) @@ -287,24 +354,56 @@ private[sql] abstract class BaseWriterContainer( protected def getWorkPath: String = { outputCommitter match { // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. - case f: FileOutputCommitter => f.getWorkPath.toString + case f: MapReduceFileOutputCommitter => f.getWorkPath.toString case _ => outputPath } } private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val committerClass = context.getConfiguration.getClass( - "mapred.output.committer.class", null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - }.getOrElse { - outputFormatClass.newInstance().getOutputCommitter(context) + val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + + if (isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the the appending job fails. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + + "for appending.") + defaultOutputCommitter + } else { + val committerClass = context.getConfiguration.getClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + } else { + // The specified output committer is just a OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + }.getOrElse { + // If output committer class is not set, we will use the one associated with the + // file output format. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") + defaultOutputCommitter + } } } - private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) this.taskId = new TaskID(this.jobId, true, splitId) @@ -330,7 +429,9 @@ private[sql] abstract class BaseWriterContainer( } def abortTask(): Unit = { - outputCommitter.abortTask(taskAttemptContext) + if (outputCommitter != null) { + outputCommitter.abortTask(taskAttemptContext) + } logError(s"Task attempt $taskAttemptId aborted.") } @@ -340,15 +441,18 @@ private[sql] abstract class BaseWriterContainer( } def abortJob(): Unit = { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + if (outputCommitter != null) { + outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + } logError(s"Job $jobId aborted.") } } private[sql] class DefaultWriterContainer( @transient relation: HadoopFsRelation, - @transient job: Job) - extends BaseWriterContainer(relation, job) { + @transient job: Job, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { @transient private var writer: OutputWriter = _ @@ -360,13 +464,26 @@ private[sql] class DefaultWriterContainer( override def outputWriterForRow(row: Row): OutputWriter = writer override def commitTask(): Unit = { - writer.close() - super.commitTask() + try { + assert(writer != null, "OutputWriter instance should have been initialized") + writer.close() + super.commitTask() + } catch { case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will + // cause `abortTask()` to be invoked. + throw new RuntimeException("Failed to commit task", cause) + } } override def abortTask(): Unit = { - writer.close() - super.abortTask() + try { + // It's possible that the task fails before `writer` gets initialized + if (writer != null) { + writer.close() + } + } finally { + super.abortTask() + } } } @@ -374,8 +491,9 @@ private[sql] class DynamicPartitionWriterContainer( @transient relation: HadoopFsRelation, @transient job: Job, partitionColumns: Array[String], - defaultPartitionName: String) - extends BaseWriterContainer(relation, job) { + defaultPartitionName: String, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { // All output writers are created on executor side. @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _ @@ -390,7 +508,7 @@ private[sql] class DynamicPartitionWriterContainer( val valueString = if (string == null || string.isEmpty) { defaultPartitionName } else { - DynamicPartitionWriterContainer.escapePathName(string) + PartitioningUtils.escapePathName(string) } s"/$col=$valueString" }.mkString.stripPrefix(Path.SEPARATOR) @@ -404,60 +522,27 @@ private[sql] class DynamicPartitionWriterContainer( }) } - override def commitTask(): Unit = { - outputWriters.values.foreach(_.close()) - super.commitTask() - } - - override def abortTask(): Unit = { - outputWriters.values.foreach(_.close()) - super.abortTask() - } -} - -private[sql] object DynamicPartitionWriterContainer { - ////////////////////////////////////////////////////////////////////////////////////////////////// - // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). - ////////////////////////////////////////////////////////////////////////////////////////////////// - - val charToEscape = { - val bitSet = new java.util.BitSet(128) - - /** - * ASCII 01-1F are HTTP control characters that need to be escaped. - * \u000A and \u000D are \n and \r, respectively. - */ - val clist = Array( - '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', - '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', - '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', - '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', - '{', '[', ']', '^') - - clist.foreach(bitSet.set(_)) - - if (Shell.WINDOWS) { - Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + private def clearOutputWriters(): Unit = { + if (outputWriters.nonEmpty) { + outputWriters.values.foreach(_.close()) + outputWriters.clear() } - - bitSet } - def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) + override def commitTask(): Unit = { + try { + clearOutputWriters() + super.commitTask() + } catch { case cause: Throwable => + throw new RuntimeException("Failed to commit task", cause) + } } - def escapePathName(path: String): String = { - val builder = new StringBuilder() - path.foreach { c => - if (DynamicPartitionWriterContainer.needsEscaping(c)) { - builder.append('%') - builder.append(f"${c.asInstanceOf[Int]}%02x") - } else { - builder.append(c) - } + override def abortTask(): Unit = { + try { + clearOutputWriters() + } finally { + super.abortTask() } - - builder.toString() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 37a569db311e..b7095c8ead79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -26,7 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InternalRow} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ @@ -130,7 +130,7 @@ private[sql] class DDLParser( } } - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" /* * describe [extended] table avroTable @@ -138,7 +138,7 @@ private[sql] class DDLParser( */ protected lazy val describeTable: Parser[LogicalPlan] = (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => + case e ~ db ~ tbl => val tblIdentifier = db match { case Some(dbName) => Seq(dbName, tbl) @@ -166,12 +166,16 @@ private[sql] class DDLParser( } ) - protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ { + protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { case name => name } + protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { + case parts => parts.mkString(".") + } + protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k,v) } + optionName ~ stringLit ^^ { case k ~ v => (k, v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => @@ -188,18 +192,20 @@ private[sql] class DDLParser( private[sql] object ResolvedDataSource { private val builtinSources = Map( - "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource], - "json" -> classOf[org.apache.spark.sql.json.DefaultSource], - "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource] + "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", + "json" -> "org.apache.spark.sql.json.DefaultSource", + "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", + "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" ) /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String): Class[_] = { + val loader = Utils.getContextOrSparkClassLoader + if (builtinSources.contains(provider)) { - return builtinSources(provider) + return loader.loadClass(builtinSources(provider)) } - val loader = Utils.getContextOrSparkClassLoader try { loader.loadClass(provider) } catch { @@ -208,7 +214,11 @@ private[sql] object ResolvedDataSource { loader.loadClass(provider + ".DefaultSource") } catch { case cnf: java.lang.ClassNotFoundException => - sys.error(s"Failed to load class for data source: $provider") + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + sys.error("The ORC data source must be used with Hive support enabled.") + } else { + sys.error(s"Failed to load class for data source: $provider") + } } } } @@ -233,18 +243,19 @@ private[sql] object ResolvedDataSource { Some(partitionColumnsSchema(schema, partitionColumns)) } - val caseInsensitiveOptions= new CaseInsensitiveMap(options) + val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray } - val dataSchema = StructType(schema.filterNot(f => partitionColumns.contains(f.name))) + val dataSchema = + StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable dataSource.createRelation( sqlContext, paths, - Some(schema), + Some(dataSchema), maybePartitionsSchema, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.RelationProvider => @@ -314,11 +325,14 @@ private[sql] object ResolvedDataSource { Some(dataSchema.asNullable), Some(partitionColumnsSchema(data.schema, partitionColumns)), caseInsensitiveOptions) + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. sqlContext.executePlan( InsertIntoHadoopFsRelation( r, data.logicalPlan, - partitionColumns.toArray, mode)).toRdd r case _ => @@ -394,7 +408,7 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext): Seq[Row] = { + def run(sqlContext: SQLContext): Seq[InternalRow] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( @@ -411,7 +425,7 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sqlContext: SQLContext): Seq[InternalRow] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.registerDataFrameAsTable( @@ -424,7 +438,7 @@ private[sql] case class CreateTempTableUsingAsSelect( private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { + override def run(sqlContext: SQLContext): Seq[InternalRow] = { // Refresh the given table's metadata first. sqlContext.catalog.refreshTable(databaseName, tableName) @@ -443,7 +457,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) sqlContext.cacheManager.cacheQuery(df, Some(tableName)) } - Seq.empty[Row] + Seq.empty[InternalRow] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 274ab4485217..0b875304f9b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,19 +17,22 @@ package org.apache.spark.sql.sources +import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ +import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** * ::DeveloperApi:: @@ -91,7 +94,7 @@ trait SchemaRelationProvider { } /** - * ::DeveloperApi:: + * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined @@ -113,16 +116,19 @@ trait SchemaRelationProvider { * * @since 1.4.0 */ +@Experimental trait HadoopFsRelationProvider { /** * Returns a new base relation with the given parameters, a user defined schema, and a list of * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity * is enforced by the Map that is passed to the function. + * + * @param dataSchema Schema of data columns (i.e., columns that are not partition columns). */ def createRelation( sqlContext: SQLContext, paths: Array[String], - schema: Option[StructType], + dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation } @@ -190,6 +196,8 @@ abstract class BaseRelation { * java.lang.String -> UTF8String * java.lang.Decimal -> Decimal * + * If `needConversion` is `false`, buildScan() should return an [[RDD]] of [[InternalRow]] + * * Note: The internal representation is not stable across releases and thus data sources outside * of Spark SQL should leave this as true. * @@ -368,18 +376,88 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio private var _partitionSpec: PartitionSpec = _ + private class FileStatusCache { + var leafFiles = mutable.Map.empty[Path, FileStatus] + + var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] + + def refresh(): Unit = { + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. + def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { + if (status.getPath.getName.toLowerCase == "_temporary") { + Set.empty + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] + files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + } + } + + leafFiles.clear() + + val statuses = paths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) + }.filterNot { status => + // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories + status.getPath.getName.startsWith(".") + } + + val files = statuses.filterNot(_.isDir) + leafFiles ++= files.map(f => f.getPath -> f).toMap + leafDirToChildrenFiles ++= files.groupBy(_.getPath.getParent) + } + } + + private lazy val fileStatusCache = { + val cache = new FileStatusCache + cache.refresh() + cache + } + + protected def cachedLeafStatuses(): Set[FileStatus] = { + fileStatusCache.leafFiles.values.toSet + } + final private[sql] def partitionSpec: PartitionSpec = { if (_partitionSpec == null) { _partitionSpec = maybePartitionSpec - .map(spec => spec.copy(partitionColumns = spec.partitionColumns.asNullable)) - .orElse(userDefinedPartitionColumns.map(PartitionSpec(_, Array.empty[Partition]))) + .flatMap { + case spec if spec.partitions.nonEmpty => + Some(spec.copy(partitionColumns = spec.partitionColumns.asNullable)) + case _ => + None + } + .orElse { + // We only know the partition columns and their data types. We need to discover + // partition values. + userDefinedPartitionColumns.map { partitionSchema => + val spec = discoverPartitions() + val partitionColumnTypes = spec.partitionColumns.map(_.dataType) + val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => + val literals = values.toSeq.zip(partitionColumnTypes).map { + case (value, dataType) => Literal.create(value, dataType) + } + val castedValues = partitionSchema.zip(literals).map { case (field, literal) => + Cast(literal, field.dataType).eval() + } + p.copy(values = InternalRow.fromSeq(castedValues)) + } + PartitionSpec(partitionSchema, castedPartitions) + } + } .getOrElse { - if (sqlContext.conf.partitionDiscoveryEnabled()) { - discoverPartitions() - } else { - PartitionSpec(StructType(Nil), Array.empty[Partition]) + if (sqlContext.conf.partitionDiscoveryEnabled()) { + discoverPartitions() + } else { + PartitionSpec(StructType(Nil), Array.empty[Partition]) + } } - } } _partitionSpec } @@ -409,26 +487,18 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio def userDefinedPartitionColumns: Option[StructType] = None private[sql] def refresh(): Unit = { + fileStatusCache.refresh() if (sqlContext.conf.partitionDiscoveryEnabled()) { _partitionSpec = discoverPartitions() } } private def discoverPartitions(): PartitionSpec = { - val basePaths = paths.map(new Path(_)) - val leafDirs = basePaths.flatMap { path => - val fs = path.getFileSystem(hadoopConf) - Try(fs.getFileStatus(path.makeQualified(fs.getUri, fs.getWorkingDirectory))) - .filter(_.isDir) - .map(SparkHadoopUtil.get.listLeafDirStatuses(fs, _)) - .getOrElse(Seq.empty[FileStatus]) - }.map(_.getPath) - - if (leafDirs.nonEmpty) { - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME) - } else { - PartitionSpec(StructType(Array.empty[StructField]), Array.empty[Partition]) - } + val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() + // We use leaf dirs containing data files to discover the schema. + val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference) } /** @@ -439,11 +509,33 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ override lazy val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column => + StructType(dataSchema ++ partitionColumns.filterNot { column => dataSchemaColumnNames.contains(column.name.toLowerCase) }) } + private[sources] final def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + val inputStatuses = inputPaths.flatMap { input => + val path = new Path(input) + + // First assumes `input` is a directory path, and tries to get all files contained in it. + fileStatusCache.leafDirToChildrenFiles.getOrElse( + path, + // Otherwise, `input` might be a file path + fileStatusCache.leafFiles.get(path).toArray + ).filter { status => + val name = status.getPath.getName + !name.startsWith("_") && !name.startsWith(".") + } + } + + buildScan(requiredColumns, filters, inputStatuses, broadcastedConf) + } + /** * Specifies schema of actual data files. For partitioned relations, if one or more partitioned * columns are contained in the data files, they should also appear in `dataSchema`. @@ -457,13 +549,13 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio * this relation. For partitioned relations, this method is called for each selected partition, * and builds an `RDD[Row]` containing all rows within that single partition. * - * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the * relation. For a partitioned relation, it contains paths of all data files in a single * selected partition. * * @since 1.4.0 */ - def buildScan(inputPaths: Array[String]): RDD[Row] = { + def buildScan(inputFiles: Array[FileStatus]): RDD[Row] = { throw new UnsupportedOperationException( "At least one buildScan() method should be overridden to read the relation.") } @@ -474,31 +566,38 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio * and builds an `RDD[Row]` containing all rows within that single partition. * * @param requiredColumns Required columns. - * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the * relation. For a partitioned relation, it contains paths of all data files in a single * selected partition. * * @since 1.4.0 */ - def buildScan(requiredColumns: Array[String], inputPaths: Array[String]): RDD[Row] = { + def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { // Yeah, to workaround serialization... val dataSchema = this.dataSchema val codegenEnabled = this.codegenEnabled + val needConversion = this.needConversion val requiredOutput = requiredColumns.map { col => val field = dataSchema(col) BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - buildScan(inputPaths).mapPartitions { rows => + val rdd = buildScan(inputFiles) + val converted = + if (needConversion) { + RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) + } else { + rdd.map(_.asInstanceOf[InternalRow]) + } + converted.mapPartitions { rows => val buildProjection = if (codegenEnabled) { GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) } - val mutableProjection = buildProjection() - rows.map(mutableProjection) + rows.map(r => mutableProjection(r).asInstanceOf[Row]) } } @@ -512,7 +611,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio * of all `filters`. The pushed down filters are currently purely an optimization as they * will all be evaluated again. This means it is safe to use them with methods that produce * false positives such as filtering partitions based on a bloom filter. - * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the * relation. For a partitioned relation, it contains paths of all data files in a single * selected partition. * @@ -521,13 +620,42 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio def buildScan( requiredColumns: Array[String], filters: Array[Filter], - inputPaths: Array[String]): RDD[Row] = { - buildScan(requiredColumns, inputPaths) + inputFiles: Array[FileStatus]): RDD[Row] = { + buildScan(requiredColumns, inputFiles) + } + + /** + * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within + * this relation. For partitioned relations, this method is called for each selected partition, + * and builds an `RDD[Row]` containing all rows within that single partition. + * + * Note: This interface is subject to change in future. + * + * @param requiredColumns Required columns. + * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction + * of all `filters`. The pushed down filters are currently purely an optimization as they + * will all be evaluated again. This means it is safe to use them with methods that produce + * false positives such as filtering partitions based on a bloom filter. + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * @param broadcastedConf A shared broadcast Hadoop Configuration, which can be used to reduce the + * overhead of broadcasting the Configuration for every Hadoop RDD. + * + * @since 1.4.0 + */ + private[sql] def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + buildScan(requiredColumns, filters, inputFiles) } /** * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can - * be put here. For example, user defined output committer can be configured here. + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. * * Note that the only side effect expected here is mutating `job` via its setters. Especially, * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index 1eacdde7413f..a3fd7f13b3db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -35,9 +35,9 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - // We are inserting into an InsertableRelation. + // We are inserting into an InsertableRelation or HadoopFsRelation. case i @ InsertIntoTable( - l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite, ifNotExists) => { + l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation), _, child, _, _) => { // First, make sure the data to be inserted have the same number of fields with the // schema of the relation. if (l.output.size != child.output.size) { @@ -101,8 +101,20 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } } - case logical.InsertIntoTable(LogicalRelation(_: InsertableRelation), _, _, _, _) => // OK - case logical.InsertIntoTable(LogicalRelation(_: HadoopFsRelation), _, _, _, _) => // OK + case logical.InsertIntoTable(LogicalRelation(r: HadoopFsRelation), part, _, _, _) => + // We need to make sure the partition columns specified by users do match partition + // columns of the relation. + val existingPartitionColumns = r.partitionColumns.fieldNames.toSet + val specifiedPartitionColumns = part.keySet + if (existingPartitionColumns != specifiedPartitionColumns) { + failAnalysis(s"Specified partition columns " + + s"(${specifiedPartitionColumns.mkString(", ")}) " + + s"do not match the partition columns of the table. Please use " + + s"(${existingPartitionColumns.mkString(", ")}) as the partition columns.") + } else { + // OK + } + case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/README.md b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md new file mode 100644 index 000000000000..d867f181b972 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md @@ -0,0 +1,7 @@ +README +====== + +Please do not add any class in this place unless it is used by `sql/console` or Python tests. +If you need to create any classes or traits that will be used by tests from both `sql/core` and +`sql/hive`, you can add them in the `src/test` of `sql/core` (tests of `sql/hive` +depend on the test jar of `sql/core`). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 356a6100d2cf..9fa394525d65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -38,7 +38,7 @@ class LocalSQLContext protected[sql] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index c344a9b095c5..fcb8f5499cf8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -187,14 +187,14 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - DataFrame df1 = sqlContext.jsonRDD(jsonRDD); + DataFrame df1 = sqlContext.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema); + DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index b76f7d421f64..2706e01bd28a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -67,7 +67,7 @@ public void setUp() throws IOException { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.jsonRDD(rdd); + df = sqlContext.read().json(rdd); df.registerTempTable("jsonTable"); } @@ -75,10 +75,8 @@ public void setUp() throws IOException { public void saveAndLoad() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); - - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); - + df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); + DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -86,12 +84,12 @@ public void saveAndLoad() { public void saveAndLoadWithSchema() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); List fields = new ArrayList(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); + DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); } diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 28e90b9520b2..12fb128149d3 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,11 +36,11 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. -log4j.additivity.parquet.hadoop.ParquetRecordReader=false -log4j.logger.parquet.hadoop.ParquetRecordReader=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF -log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false -log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF @@ -52,5 +52,5 @@ log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF # Parquet related logging -log4j.logger.parquet.hadoop=WARN +log4j.logger.org.apache.parquet.hadoop=WARN log4j.logger.org.apache.spark.sql.parquet=INFO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0772e5e18742..eb3e91332206 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,17 +25,19 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan + val executedPlan = ctx.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -45,151 +47,151 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - cacheTable("tempTable") + ctx.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != cacheManager.lookupCachedData(testData)) + assert(None != ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - cacheTable("tempTable1") + ctx.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - uncacheTable("tempTable2") + ctx.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) } test("too big for memory") { - val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + val data = "*" * 1000 + ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) - table("bigData").unpersist(blocking = true) + ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(ctx.table("bigData").count() === 200000L) + ctx.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - table("testData").cache() - assertCached(table("testData")) - table("testData").unpersist(blocking = true) + ctx.table("testData").cache() + assertCached(ctx.table("testData")) + ctx.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - table("testData").cache() - table("testData").count() - table("testData").unpersist(blocking = true) - assertCached(table("testData"), 0) + ctx.table("testData").cache() + ctx.table("testData").count() + ctx.table("testData").unpersist(blocking = true) + assertCached(ctx.table("testData"), 0) } test("isCached") { - cacheTable("testData") + ctx.cacheTable("testData") - assertCached(table("testData")) - assert(table("testData").queryExecution.withCachedData match { + assertCached(ctx.table("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - uncacheTable("testData") - assert(!isCached("testData")) - assert(table("testData").queryExecution.withCachedData match { + ctx.uncacheTable("testData") + assert(!ctx.isCached("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - cacheTable("testData") - assertCached(table("testData")) + ctx.cacheTable("testData") + assertCached(ctx.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - cacheTable("testData") + ctx.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - uncacheTable("testData") + ctx.uncacheTable("testData") } test("read from cached table and uncache") { - cacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData")) + ctx.cacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData")) - uncacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData"), 0) + ctx.uncacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - uncacheTable("testData") + ctx.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - cacheTable("selectStar") + ctx.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - uncacheTable("selectStar") + ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - cacheTable("testData") + ctx.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - uncacheTable("testData") + ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -197,7 +199,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!isCached("testData"), "Table 'testData' should not be cached") + assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -206,14 +208,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -221,14 +223,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -236,7 +238,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -248,7 +250,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - uncacheTable("testData") + ctx.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -256,7 +258,7 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -265,38 +267,38 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - table("t1") - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + ctx.table("t1") + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - cacheTable("t1") + ctx.cacheTable("t1") - assert(isCached("t1")) - assert(isCached("t2")) + assert(ctx.isCached("t1")) + assert(ctx.isCached("t2")) - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) - assert(!isCached("t2")) + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!ctx.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") - clearCache() - assert(cacheManager.isEmpty) + ctx.cacheTable("t1") + ctx.cacheTable("t2") + ctx.clearCache() + assert(ctx.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") sql("Clear CACHE") - assert(cacheManager.isEmpty) + assert(ctx.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { @@ -305,8 +307,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") assert((accsSize + 2) == Accumulators.originals.size) } @@ -317,8 +319,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - uncacheTable("t1") - uncacheTable("t2") + ctx.uncacheTable("t1") + ctx.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 9bdf201b3be7..88bb743ab0bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,23 +19,31 @@ package org.apache.spark.sql import org.scalatest.Matchers._ +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("alias") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + assert(df.select(df("a").as("b")).columns.head === "b") + assert(df.select(df("a").alias("b")).columns.head === "b") + } + test("single explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select(explode('intList)), Row(1) :: Row(2) :: Row(3) :: Nil) } test("explode and other columns") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select($"a", explode('intList)), @@ -45,13 +53,13 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer( df.select($"*", explode('intList)), - Row(1, Seq(1,2,3), 1) :: - Row(1, Seq(1,2,3), 2) :: - Row(1, Seq(1,2,3), 3) :: Nil) + Row(1, Seq(1, 2, 3), 1) :: + Row(1, Seq(1, 2, 3), 2) :: + Row(1, Seq(1, 2, 3), 3) :: Nil) } test("aliased explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( df.select(explode('intList).as('int)).select('int), @@ -79,7 +87,7 @@ class ColumnExpressionSuite extends QueryTest { } test("self join explode") { - val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") val exploded = df.select(explode('intList).as('i)) checkAnswer( @@ -177,12 +185,20 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer( nullStrings.toDF.where($"s".isNull), nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + + checkAnswer( + ctx.sql("select isnull(null), isnull(1)"), + Row(true, false)) } test("isNotNull") { checkAnswer( nullStrings.toDF.where($"s".isNotNull), nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + + checkAnswer( + ctx.sql("select isnotnull(null), isnotnull('a')"), + Row(false, true)) } test("===") { @@ -206,7 +222,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -267,7 +283,7 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -280,7 +296,23 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } - val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + test("in") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + checkAnswer(df.filter($"a".in(1, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 1)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"b".in("y", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "y")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + } + + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -353,23 +385,6 @@ class ColumnExpressionSuite extends QueryTest { ) } - test("abs") { - checkAnswer( - testData.select(abs('key)).orderBy('key.asc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - negativeData.select(abs('key)).orderBy('key.desc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - testData.select(abs(lit(null))), - (1 to 100).map(_ => Row(null)) - ) - } - test("upper") { checkAnswer( lowerCaseData.select(upper('l)), @@ -385,6 +400,10 @@ class ColumnExpressionSuite extends QueryTest { testData.select(upper(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + ctx.sql("SELECT upper('aB'), ucase('cDe')"), + Row("AB", "CDE")) } test("lower") { @@ -402,11 +421,15 @@ class ColumnExpressionSuite extends QueryTest { testData.select(lower(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + ctx.sql("SELECT lower('aB'), lcase('cDe')"), + Row("ab", "cde")) } test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -416,7 +439,7 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( df.select(sparkPartitionId()), Row(0) @@ -446,13 +469,51 @@ class ColumnExpressionSuite extends QueryTest { } test("rand") { - val randCol = testData.select('key, rand(5L).as("rand")) + val randCol = testData.select($"key", rand(5L).as("rand")) randCol.columns.length should be (2) val rows = randCol.collect() rows.foreach { row => assert(row.getDouble(1) <= 1.0) assert(row.getDouble(1) >= 0.0) } + + def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { + val projects = df.queryExecution.executedPlan.collect { + case project: Project => project + } + assert(projects.size === expectedNumProjects) + } + + // We first create a plan with two Projects. + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // Because Rand function is not deterministic, the column rand is not deterministic. + // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2] + // and Project [key, Rand 5 AS rand]. The final plan still has two Projects. + val dfWithTwoProjects = + testData + .select($"key", (rand(5L) + 1).as("rand")) + .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2")) + checkNumProjects(dfWithTwoProjects, 2) + + // Now, we add one more project rand1 - rand2 on top of the query plan. + // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated + // rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2. + // So, the plan will be optimized from ... + // Project [(rand1 - rand2) AS (rand1 - rand2)] + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // to ... + // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] + // Project [key, Rand 5 AS rand] + // LogicalRDD [key, value] + val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") + checkNumProjects(dfWithThreeProjects, 2) + dfWithThreeProjects.collect().foreach { row => + assert(row.getDouble(0) === 2.0 +- 0.0001) + } } test("randn") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 35a574f35474..b26d3ab253a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType class DataFrameAggregateSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -67,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false") + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true") + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { @@ -148,12 +149,12 @@ class DataFrameAggregateSuite extends QueryTest { test("null count") { checkAnswer( testData3.groupBy('a).agg(count('b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( testData3.groupBy('a).agg(count('a + 'b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala new file mode 100644 index 000000000000..a4719a38de1d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +class DataFrameDateTimeSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("timestamp comparison with date strings") { + val df = Seq( + (1, Timestamp.valueOf("2015-01-01 00:00:00")), + (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2014-06-01"), + Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + } + + test("date comparison with date strings") { + val df = Seq( + (1, Date.valueOf("2015-01-01")), + (2, Date.valueOf("2014-01-01"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Date.valueOf("2014-01-01")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2015"), + Row(Date.valueOf("2015-01-01")) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b1e0faa310b6..afba28515e03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ /** @@ -27,6 +26,9 @@ import org.apache.spark.sql.types._ */ class DataFrameFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") val row = df.select(array("a", "b")).first() @@ -77,10 +79,53 @@ class DataFrameFunctionsSuite extends QueryTest { assert(row.getAs[Row](0) === Row(2, "str")) } - test("struct: must use named column expression") { - intercept[IllegalArgumentException] { - struct(col("a") * 2) - } + test("struct with column expression to be automatically named") { + val df = Seq((1, "str")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), col("b"))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Row(Row(2, "str"))) + } + + test("struct with literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct((col("a") * 2), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", IntegerType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row(2, 5.0)), Row(Row(4, 5.0)))) + } + + test("struct with all literal columns") { + val df = Seq((1, "str1"), (2, "str2")).toDF("a", "b") + val result = df.select(struct(lit("v"), lit(5.0))) + + val expectedType = StructType(Seq( + StructField("col1", StringType, nullable = false), + StructField("col2", DoubleType, nullable = false) + )) + + assert(result.first.schema(0).dataType === expectedType) + checkAnswer(result, Seq(Row(Row("v", 5.0)), Row(Row("v", 5.0)))) + } + + test("constant functions") { + checkAnswer( + ctx.sql("SELECT E()"), + Row(scala.math.E) + ) + checkAnswer( + ctx.sql("SELECT PI()"), + Row(scala.math.Pi) + ) } test("bitwiseNOT") { @@ -88,4 +133,144 @@ class DataFrameFunctionsSuite extends QueryTest { testData2.select(bitwiseNOT($"a")), testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) } + + test("bin") { + val df = Seq[(Integer, Integer)]((12, null)).toDF("a", "b") + checkAnswer( + df.select(bin("a"), bin("b")), + Row("1100", null)) + checkAnswer( + df.selectExpr("bin(a)", "bin(b)"), + Row("1100", null)) + } + + test("if function") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"), + Row("one", "not_one")) + } + + test("nvl function") { + checkAnswer( + ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + Row("x", "y", null)) + } + + test("misc md5 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(md5($"a"), md5("b")), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + + checkAnswer( + df.selectExpr("md5(a)", "md5(b)"), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + } + + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1("b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + + test("misc sha2 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(sha2($"a", 256), sha2("b", 256)), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + checkAnswer( + df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + intercept[IllegalArgumentException] { + df.select(sha2($"a", 1024)) + } + } + + test("misc crc32 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(crc32($"a"), crc32("b")), + Row(2743272264L, 2180413220L)) + + checkAnswer( + df.selectExpr("crc32(a)", "crc32(b)"), + Row(2743272264L, 2180413220L)) + } + + test("string length function") { + checkAnswer( + nullStrings.select(strlen($"s"), strlen("s")), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l, l) + }) + + checkAnswer( + nullStrings.selectExpr("length(s)"), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l) + }) + } + + test("Levenshtein distance") { + val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") + checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) + checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) + } + + test("string ascii function") { + val df = Seq(("abc", "")).toDF("a", "b") + checkAnswer( + df.select(ascii($"a"), ascii("b")), + Row(97, 0)) + + checkAnswer( + df.selectExpr("ascii(a)", "ascii(b)"), + Row(97, 0)) + } + + test("string base64/unbase64 function") { + val bytes = Array[Byte](1, 2, 3, 4) + val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") + checkAnswer( + df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")), + Row("AQIDBA==", "AQIDBA==", bytes, bytes)) + + checkAnswer( + df.selectExpr("base64(a)", "unbase64(b)"), + Row("AQIDBA==", bytes)) + } + + test("string encode/decode function") { + val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) + // scalastyle:off + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") + checkAnswer( + df.select( + encode($"a", "utf-8"), + encode("a", "utf-8"), + decode($"c", "utf-8"), + decode("c", "utf-8")), + Row(bytes, bytes, "大千世界", "大千世界")) + + checkAnswer( + df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), + Row(bytes, "大千世界")) + // scalastyle:on + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 2d2367d6e729..fbb30706a494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} -import org.apache.spark.sql.test.TestSQLContext.implicits._ - - class DataFrameImplicitsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("RDD of tuples") { checkAnswer( - sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -37,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - sc.parallelize(1 to 10).toDF("intCol"), + ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - sc.parallelize(1L to 10L).toDF("longCol"), + ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 787f3f175fea..e1c6c706242d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameJoinSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") @@ -34,6 +35,15 @@ class DataFrameJoinSuite extends QueryTest { Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) } + test("join - join using multiple columns") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "int2")), + Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -49,7 +59,8 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) } test("join - using aliases after self join") { @@ -83,4 +94,20 @@ class DataFrameJoinSuite extends QueryTest { left.join(right, left("key") === right("key")), Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } + + test("broadcast join hint") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + + // equijoin - should be converted into broadcast join + val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) + + // no join key -- should not be a broadcast join + val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) + + // planner should not crash without a join + broadcast(df1).queryExecution.executedPlan + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 41b4f02e6a29..495701d4f616 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameNaFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( ("Bob", 16, 176.5), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b1845a9180..7ba4ba73e0cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql -import org.scalatest.FunSuite +import java.util.Random + import org.scalatest.Matchers._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.SparkFunSuite + +class DataFrameStatSuite extends SparkFunSuite { + + private val sqlCtx = org.apache.spark.sql.test.TestSQLContext + import sqlCtx.implicits._ -class DataFrameStatSuite extends FunSuite { - - val sqlCtx = TestSQLContext - def toLetter(i: Int): String = (i + 97).toChar.toString + private def toLetter(i: Int): String = (i + 97).toChar.toString test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") @@ -65,22 +67,52 @@ class DataFrameStatSuite extends FunSuite { } test("crosstab") { - val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b") + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "0") - assert(columnNames(2) === "1") - val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) - assert(rows(0).get(0).toString === "0") - assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === null) - assert(rows(1).get(0).toString === "1") - assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === null) - assert(rows(2).get(0).toString === "2") - assert(rows(2).getLong(1) === 2L) - assert(rows(2).getLong(2) === 1L) + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 until columnNames.length) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } + } + + test("special crosstab elements (., '', null, ``)") { + val data = Seq( + ("a", Double.NaN, "ho"), + (null, 2.0, "ho"), + ("a.b", Double.NegativeInfinity, ""), + ("b", Double.PositiveInfinity, "`ha`"), + ("a", 1.0, null) + ) + val df = data.toDF("1", "2", "3") + val ct1 = df.stat.crosstab("1", "2") + // column fields should be 1 + distinct elements of second column + assert(ct1.schema.fields.length === 6) + assert(ct1.collect().length === 4) + val ct2 = df.stat.crosstab("1", "3") + assert(ct2.schema.fields.length === 5) + assert(ct2.schema.fieldNames.contains("ha")) + assert(ct2.collect().length === 4) + val ct3 = df.stat.crosstab("3", "2") + assert(ct3.schema.fields.length === 6) + assert(ct3.schema.fieldNames.contains("NaN")) + assert(ct3.schema.fieldNames.contains("Infinity")) + assert(ct3.schema.fieldNames.contains("-Infinity")) + assert(ct3.collect().length === 4) + val ct4 = df.stat.crosstab("3", "1") + assert(ct4.schema.fields.length === 5) + assert(ct4.schema.fieldNames.contains("null")) + assert(ct4.schema.fieldNames.contains("a.b")) + assert(ct4.collect().length === 4) } test("Frequent Items") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1d5f6b3aad6f..afb1cf5f8d1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,17 +21,19 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("analysis error should be eagerly reported") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -45,11 +47,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) testData.select('nonExistentName) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) } test("dataframe toString") { @@ -59,7 +61,7 @@ class DataFrameSuite extends QueryTest { } test("rename nested groupby") { - val df = Seq((1,(1,1))).toDF() + val df = Seq((1, (1, 1))).toDF() checkAnswer( df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"), @@ -67,12 +69,12 @@ class DataFrameSuite extends QueryTest { } test("invalid plan toString, debug mode") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + val oldSetting = ctx.conf.dataFrameEagerAnalysis + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - TestSQLContext.debug() + ctx.debug() val badPlan = testData.select('badColumn) @@ -81,7 +83,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) } test("access complex data") { @@ -97,8 +99,8 @@ class DataFrameSuite extends QueryTest { } test("empty data frame") { - assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(TestSQLContext.emptyDataFrame.count() === 0) + assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(ctx.emptyDataFrame.count() === 0) } test("head and take") { @@ -132,6 +134,14 @@ class DataFrameSuite extends QueryTest { ) } + test("explode alias and star") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + + checkAnswer( + df.select(explode($"a").as("a"), $"*"), + Row("a", Seq("a"), 1) :: Nil) + } + test("selectExpr") { checkAnswer( testData.selectExpr("abs(key)", "value"), @@ -150,6 +160,12 @@ class DataFrameSuite extends QueryTest { testData.collect().filter(_.getInt(0) > 90).toSeq) } + test("filterExpr using where") { + checkAnswer( + testData.where("key > 50"), + testData.collect().filter(_.getInt(0) > 50).toSeq) + } + test("repartition") { checkAnswer( testData.select('key).repartition(10).select('key), @@ -211,23 +227,23 @@ class DataFrameSuite extends QueryTest { test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( testData2.orderBy(asc("a"), desc("b")), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( arrayData.toDF().orderBy('data.getItem(0).asc), @@ -291,7 +307,7 @@ class DataFrameSuite extends QueryTest { ) } - test("call udf in SQLContext") { + test("deprecated callUdf in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") val sqlctx = df.sqlContext sqlctx.udf.register("simpleUdf", (v: Int) => v * v) @@ -300,6 +316,15 @@ class DataFrameSuite extends QueryTest { Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) } + test("callUDF in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUDF", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUDF("simpleUDF", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( @@ -311,7 +336,7 @@ class DataFrameSuite extends QueryTest { } test("replace column using withColumn") { - val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -331,7 +356,52 @@ class DataFrameSuite extends QueryTest { checkAnswer( df, testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key","value")) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop column using drop with column reference") { + val col = testData("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) + } + + test("drop unknown column (no-op) with column reference") { + val col = Column("random") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop unknown column with same name (no-op) with column reference") { + val col = Column("key") + val df = testData.drop(col) + checkAnswer( + df, + testData.collect().toSeq) + assert(df.schema.map(_.name) === Seq("key", "value")) + } + + test("drop column after join with duplicate columns using column reference") { + val newSalary = salary.withColumnRenamed("personId", "id") + val col = newSalary("id") + // this join will result in duplicate "id" columns + val joinedDf = person.join(newSalary, + person("id") === newSalary("id"), "inner") + // remove only the "id" column that was associated with newSalary + val df = joinedDf.drop(col) + checkAnswer( + df, + joinedDf.collect().map { + case Row(id: Int, name: String, age: Int, idToDrop: Int, salary: Double) => + Row(id, name, age, salary) + }.toSeq) + assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary")) + assert(df("id") == person("id")) } test("withColumnRenamed") { @@ -347,7 +417,7 @@ class DataFrameSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -364,30 +434,35 @@ class DataFrameSuite extends QueryTest { test("describe") { val describeTestData = Seq( - ("Bob", 16, 176), + ("Bob", 16, 176), ("Alice", 32, 164), ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + ("Amy", 24, 180)).toDF("name", "age", "height") val describeResult = Seq( - Row("count", 4, 4), - Row("mean", 33.0, 178.0), - Row("stddev", 16.583123951777, 10.0), - Row("min", 16, 164), - Row("max", 60, 192)) + Row("count", "4", "4"), + Row("mean", "33.0", "178.0"), + Row("stddev", "16.583123951777", "10.0"), + Row("min", "16", "164"), + Row("max", "60", "192")) val emptyDescribeResult = Seq( - Row("count", 0, 0), - Row("mean", null, null), - Row("stddev", null, null), - Row("min", null, null), - Row("max", null, null)) + Row("count", "0", "0"), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) val describeTwoCols = describeTestData.describe("age", "height") assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) checkAnswer(describeTwoCols, describeResult) + // All aggregate value should have been cast to string + describeTwoCols.collect().foreach { row => + assert(row.get(1).isInstanceOf[String], "expected string but found " + row.get(1).getClass) + assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) + } val describeAllCols = describeTestData.describe() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) @@ -417,12 +492,84 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show(1000) } + test("showString: truncate = [true, false]") { + val longString = Array.fill(21)("1").mkString + val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = """+---------------------+ + ||_1 | + |+---------------------+ + ||1 | + ||111111111111111111111| + |+---------------------+ + |""".stripMargin + assert(df.showString(10, false) === expectedAnswerForFalse) + val expectedAnswerForTrue = """+--------------------+ + || _1| + |+--------------------+ + || 1| + ||11111111111111111...| + |+--------------------+ + |""".stripMargin + assert(df.showString(10, true) === expectedAnswerForTrue) + } + + test("showString(negative)") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |only showing top 0 rows + |""".stripMargin + assert(testData.select($"*").showString(-1) === expectedAnswer) + } + + test("showString(0)") { + val expectedAnswer = """+---+-----+ + ||key|value| + |+---+-----+ + |+---+-----+ + |only showing top 0 rows + |""".stripMargin + assert(testData.select($"*").showString(0) === expectedAnswer) + } + + test("showString: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = """+---------+---------+ + || _1| _2| + |+---------+---------+ + ||[1, 2, 3]|[1, 2, 3]| + ||[2, 3, 4]|[2, 3, 4]| + |+---------+---------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + + test("showString: minimum column width") { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + val expectedAnswer = """+---+---+ + || _1| _2| + |+---+---+ + || 1| 1| + || 2| 2| + |+---+---+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + test("SPARK-7319 showString") { val expectedAnswer = """+---+-----+ ||key|value| |+---+-----+ || 1| 1| |+---+-----+ + |only showing top 1 row |""".stripMargin assert(testData.select($"*").showString(1) === expectedAnswer) } @@ -437,19 +584,22 @@ class DataFrameSuite extends QueryTest { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = TestSQLContext.createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.rdd.collect() } test("SPARK-6899") { - val originalValue = TestSQLContext.conf.codegenEnabled - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + val originalValue = ctx.conf.codegenEnabled + ctx.setConf(SQLConf.CODEGEN_ENABLED, true) + try{ + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + } finally { + ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) + } } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -460,14 +610,14 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = TestSQLContext.jsonRDD(TestSQLContext.sparkContext.makeRDD( + val df = ctx.read.json(ctx.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = TestSQLContext.jsonRDD(TestSQLContext.sparkContext.makeRDD( + val df2 = ctx.read.json(ctx.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -487,7 +637,7 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7324 dropDuplicates") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -532,4 +682,59 @@ class DataFrameSuite extends QueryTest { val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] assert(!p.child.isInstanceOf[Project]) } + + test("SPARK-7150 range api") { + // numSlice is greater than length + val res1 = ctx.range(0, 10, 1, 15).select("id") + assert(res1.count == 10) + assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res2 = ctx.range(3, 15, 3, 2).select("id") + assert(res2.count == 4) + assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) + + val res3 = ctx.range(1, -2).select("id") + assert(res3.count == 0) + + // start is positive, end is negative, step is negative + val res4 = ctx.range(1, -2, -2, 6).select("id") + assert(res4.count == 2) + assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) + + // start, end, step are negative + val res5 = ctx.range(-3, -8, -2, 1).select("id") + assert(res5.count == 3) + assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) + + // start, end are negative, step is positive + val res6 = ctx.range(-8, -4, 2, 1).select("id") + assert(res6.count == 2) + assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) + + val res7 = ctx.range(-10, -9, -20, 1).select("id") + assert(res7.count == 0) + + val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + assert(res8.count == 3) + assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) + + val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + assert(res9.count == 2) + assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + + // only end provided as argument + val res10 = ctx.range(10).select("id") + assert(res10.count == 10) + assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res11 = ctx.range(-1).select("id") + assert(res11.count == 0) + } + + test("SPARK-8621: support empty string column name") { + val df = Seq(Tuple1(1)).toDF("").as("t") + // We should allow empty string as column name + df.col("") + df.col("t.``") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala new file mode 100644 index 000000000000..44b915304533 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.functions._ + +class DatetimeExpressionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + import ctx.implicits._ + + lazy val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + + test("function current_date") { + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) + val d2 = DateTimeUtils.fromJavaDate( + ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) + } + + test("function current_timestamp") { + checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value + checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + Row(true)) + assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + 0).getTime - System.currentTimeMillis()) < 5000) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 037d392c1f92..8953889d1fae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,33 +20,35 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.logicalPlanToSparkQuery + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = sql(sqlString) + val df = ctx.sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j - case j: HashOuterJoin => j + case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -61,9 +63,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -80,12 +82,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -94,22 +97,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("broadcasted hash join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), @@ -117,7 +120,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -126,17 +129,45 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + + ctx.sql("UNCACHE TABLE testData") + } + + test("broadcasted hash outer join operator selection") { + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } - sql("UNCACHE TABLE testData") + ctx.sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } @@ -167,10 +198,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where($"a" === 1).as("y") checkAnswer( x.join(y).where($"x.a" === $"y.a"), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil ) } @@ -241,7 +272,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -255,7 +286,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -301,7 +332,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -310,7 +341,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -362,7 +393,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -371,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -386,7 +417,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -401,7 +432,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -411,11 +442,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") - val tmp = conf.autoBroadcastJoinThreshold + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -423,7 +454,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -431,12 +462,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) - sql("UNCACHE TABLE testData") + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) + ctx.sql("UNCACHE TABLE testData") } test("left semi join") { - val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index f9f41eb358bd..2089660c52bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,49 +19,47 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} class ListTablesSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ - val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") } after { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - tables().filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -69,19 +67,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(tables(), sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + ctx.sql( + "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - dropTempTable("tables") + ctx.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index c4281c4b55c0..24bef21b999e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,36 +17,29 @@ package org.apache.spark.sql -import java.lang.{Double => JavaDouble} - import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ - -private[this] object MathExpressionsTestData { - - case class DoubleData(a: JavaDouble, b: JavaDouble) - val doubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() - - val nnDoubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() - - case class NullDoubles(a: JavaDouble) - val nullDoubles = - TestSQLContext.sparkContext.parallelize( - NullDoubles(1.0) :: - NullDoubles(2.0) :: - NullDoubles(3.0) :: - NullDoubles(null) :: Nil - ).toDF() +import org.apache.spark.sql.functions.{log => logarithm} + +private object MathExpressionsTestData { + case class DoubleData(a: java.lang.Double, b: java.lang.Double) + case class NullDoubles(a: java.lang.Double) } class MathExpressionsSuite extends QueryTest { import MathExpressionsTestData._ - def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() + + private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF() + + private lazy val nullDoubles = + Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() + + private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( c: Column => Column, f: T => T): Unit = { checkAnswer( @@ -65,7 +58,8 @@ class MathExpressionsSuite extends QueryTest { ) } - def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = + { checkAnswer( nnDoubleData.select(c('a)), (1 to 10).map(n => Row(f(n * 0.1))) @@ -89,7 +83,7 @@ class MathExpressionsSuite extends QueryTest { ) } - def testTwoToOneMathFunction( + private def testTwoToOneMathFunction( c: (Column, Column) => Column, d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { @@ -157,26 +151,49 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(tanh, math.tanh) } - test("toDeg") { + test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) + checkAnswer( + ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) + ) } - test("toRad") { + test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) + checkAnswer( + ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) + ) } test("cbrt") { testOneToOneMathFunction(cbrt, math.cbrt) } - test("ceil") { + test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) + checkAnswer( + ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + Row(0.0, 1.0, 2.0)) } test("floor") { testOneToOneMathFunction(floor, math.floor) } + test("factorial") { + val df = (0 to 5).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(factorial('a)), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + checkAnswer( + df.selectExpr("factorial(a)"), + Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) + ) + } + test("rint") { testOneToOneMathFunction(rint, math.rint) } @@ -189,12 +206,44 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(expm1, math.expm1) } - test("signum") { + test("signum / sign") { testOneToOneMathFunction[Double](signum, math.signum) + + checkAnswer( + ctx.sql("SELECT sign(10), signum(-11)"), + Row(1, -1)) } - test("pow") { + test("pow / power") { testTwoToOneMathFunction(pow, pow, math.pow) + + checkAnswer( + ctx.sql("SELECT pow(1, 2), power(2, 1)"), + Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) + ) + } + + test("hex") { + val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") + checkAnswer(data.select(hex('a)), Seq(Row("1C"))) + checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) + checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) + checkAnswer(data.selectExpr("hex(d)"), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) + } + + test("unhex") { + val data = Seq(("1C", "737472696E67")).toDF("a", "b") + checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) + checkAnswer(data.select(unhex('b)), Row("string".getBytes)) + checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) + checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes)) + checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) + checkAnswer(data.selectExpr("""unhex("G123")"""), Row(null)) } test("hypot") { @@ -205,8 +254,12 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(atan2, atan2, math.atan2) } - test("log") { - testOneToOneNonNegativeMathFunction(log, math.log) + test("log / ln") { + testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) + checkAnswer( + ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) + ) } test("log10") { @@ -217,4 +270,111 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("shift left") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1), + shiftLeft('f, 1)), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)", + "shiftLeft(f, 1)"), + Row(42.toLong, 42, 42.toShort, 42.toByte, null)) + } + + test("shift right") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1), + shiftRight('f, 1)), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)", + "shiftRight(f, 1)"), + Row(21.toLong, 21, 21.toShort, 21.toByte, null)) + } + + test("shift right unsigned") { + val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null)) + .toDF("a", "b", "c", "d", "e", "f") + + checkAnswer( + df.select( + shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1), + shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + + checkAnswer( + df.selectExpr( + "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)", + "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"), + Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) + } + + test("binary log") { + val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") + checkAnswer( + df.select(org.apache.spark.sql.functions.log("a"), + org.apache.spark.sql.functions.log(2.0, "a"), + org.apache.spark.sql.functions.log("b")), + Row(math.log(123), math.log(123) / math.log(2), null)) + + checkAnswer( + df.selectExpr("log(a)", "log(2.0, a)", "log(b)"), + Row(math.log(123), math.log(123) / math.log(2), null)) + } + + test("abs") { + val input = + Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5)) + checkAnswer( + input.toDF("key", "value").select(abs($"key").alias("a")).sort("a"), + input.map(pair => Row(pair._2))) + + checkAnswer( + input.toDF("key", "value").selectExpr("abs(key) a").sort("a"), + input.map(pair => Row(pair._2))) + } + + test("log2") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.select(log2("b") + log2("a")), + Row(1)) + + checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + } + + test("sqrt") { + val df = Seq((1, 4)).toDF("a", "b") + checkAnswer( + df.select(sqrt("a"), sqrt("b")), + Row(1.0, 2.0)) + + checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) + } + + test("negative") { + checkAnswer( + ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + Row(-1, 0, 1)) + } + + test("positive") { + val df = Seq((1, -1, "abc")).toDF("a", "b", "c") + checkAnswer(df.selectExpr("positive(a)"), Row(1)) + checkAnswer(df.selectExpr("positive(b)"), Row(-1)) + checkAnswer(df.selectExpr("positive(c)"), Row("abc")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index bbf9ab113ca4..98ba3c99283a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -67,6 +67,10 @@ class QueryTest extends PlanTest { checkAnswer(df, Seq(expectedAnswer)) } + protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index fb3ba4bc1b90..d84b57af9c88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer -import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -class RowSuite extends FunSuite { +class RowSuite extends SparkFunSuite { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ test("create row") { val expected = new GenericMutableRow(4) @@ -56,7 +57,7 @@ class RowSuite extends FunSuite { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala new file mode 100644 index 000000000000..2e33777f14ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf._ + +class SQLConfEntrySuite extends SparkFunSuite { + + val conf = new SQLConf + + test("intConf") { + val key = "spark.sql.SQLConfEntrySuite.int" + val confEntry = SQLConfEntry.intConf(key) + assert(conf.getConf(confEntry, 5) === 5) + + conf.setConf(confEntry, 10) + assert(conf.getConf(confEntry, 5) === 10) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5) === 20) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be int, but was abc") + } + + test("longConf") { + val key = "spark.sql.SQLConfEntrySuite.long" + val confEntry = SQLConfEntry.longConf(key) + assert(conf.getConf(confEntry, 5L) === 5L) + + conf.setConf(confEntry, 10L) + assert(conf.getConf(confEntry, 5L) === 10L) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5L) === 20L) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be long, but was abc") + } + + test("booleanConf") { + val key = "spark.sql.SQLConfEntrySuite.boolean" + val confEntry = SQLConfEntry.booleanConf(key) + assert(conf.getConf(confEntry, false) === false) + + conf.setConf(confEntry, true) + assert(conf.getConf(confEntry, false) === true) + + conf.setConfString(key, "true") + assert(conf.getConfString(key, "false") === "true") + assert(conf.getConfString(key) === "true") + assert(conf.getConf(confEntry, false) === true) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be boolean, but was abc") + } + + test("doubleConf") { + val key = "spark.sql.SQLConfEntrySuite.double" + val confEntry = SQLConfEntry.doubleConf(key) + assert(conf.getConf(confEntry, 5.0) === 5.0) + + conf.setConf(confEntry, 10.0) + assert(conf.getConf(confEntry, 5.0) === 10.0) + + conf.setConfString(key, "20.0") + assert(conf.getConfString(key, "5.0") === "20.0") + assert(conf.getConfString(key) === "20.0") + assert(conf.getConf(confEntry, 5.0) === 20.0) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be double, but was abc") + } + + test("stringConf") { + val key = "spark.sql.SQLConfEntrySuite.string" + val confEntry = SQLConfEntry.stringConf(key) + assert(conf.getConf(confEntry, "abc") === "abc") + + conf.setConf(confEntry, "abcd") + assert(conf.getConf(confEntry, "abc") === "abcd") + + conf.setConfString(key, "abcde") + assert(conf.getConfString(key, "abc") === "abcde") + assert(conf.getConfString(key) === "abcde") + assert(conf.getConf(confEntry, "abc") === "abcde") + } + + test("enumConf") { + val key = "spark.sql.SQLConfEntrySuite.enum" + val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a")) + assert(conf.getConf(confEntry) === "a") + + conf.setConf(confEntry, "b") + assert(conf.getConf(confEntry) === "b") + + conf.setConfString(key, "c") + assert(conf.getConfString(key, "a") === "c") + assert(conf.getConfString(key) === "c") + assert(conf.getConf(confEntry) === "c") + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "d") + } + assert(e.getMessage === s"The value of $key should be one of a, b, c, but was d") + } + + test("stringSeqConf") { + val key = "spark.sql.SQLConfEntrySuite.stringSeq" + val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq", + defaultValue = Some(Nil)) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c")) + + conf.setConf(confEntry, Seq("a", "b", "c", "d")) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d")) + + conf.setConfString(key, "a,b,c,d,e") + assert(conf.getConfString(key, "a,b,c") === "a,b,c,d,e") + assert(conf.getConfString(key) === "a,b,c,d,e") + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index bf73d0c7074a..75791e9d53c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,68 +17,72 @@ package org.apache.spark.sql -import org.scalatest.FunSuiteLike -import org.apache.spark.sql.test._ +class SQLConfSuite extends QueryTest { -/* Implicits */ -import TestSQLContext._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext -class SQLConfSuite extends QueryTest with FunSuiteLike { - - val testKey = "test.key.0" - val testVal = "test.val.0" + private val testKey = "test.key.0" + private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(TestSQLContext.sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") == "true") + val newContext = new SQLContext(ctx.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - conf.clear() - assert(getAllConfs.size === 0) + ctx.conf.clear() + assert(ctx.getAllConfs.size === 0) - setConf(testKey, testVal) - assert(getConf(testKey) == testVal) - assert(getConf(testKey, testVal + "_") == testVal) - assert(getAllConfs.contains(testKey)) + ctx.setConf(testKey, testVal) + assert(ctx.getConf(testKey) === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(TestSQLContext.getConf(testKey) == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getAllConfs.contains(testKey)) + assert(ctx.getConf(testKey) == testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) - conf.clear() + ctx.conf.clear() } test("parse SQL set commands") { - conf.clear() - sql(s"set $testKey=$testVal") - assert(getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + ctx.conf.clear() + ctx.sql(s"set $testKey=$testVal") + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) - sql("set some.property=20") - assert(getConf("some.property", "0") == "20") - sql("set some.property = 40") - assert(getConf("some.property", "0") == "40") + ctx.sql("set some.property=20") + assert(ctx.getConf("some.property", "0") === "20") + ctx.sql("set some.property = 40") + assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - sql(s"set $key=$vs") - assert(getConf(key, "0") == vs) + ctx.sql(s"set $key=$vs") + assert(ctx.getConf(key, "0") === vs) - sql(s"set $key=") - assert(getConf(key, "0") == "") + ctx.sql(s"set $key=") + assert(ctx.getConf(key, "0") === "") - conf.clear() + ctx.conf.clear() } test("deprecated property") { - conf.clear() - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") + ctx.conf.clear() + ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(ctx.conf.numShufflePartitions === 10) + } + + test("invalid conf value") { + ctx.conf.clear() + val e = intercept[IllegalArgumentException] { + ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + } + assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala new file mode 100644 index 000000000000..c8d8796568a4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -0,0 +1,48 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite + +class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + override def afterAll(): Unit = { + SQLContext.setLastInstantiatedContext(ctx) + } + + test("getOrCreate instantiates SQLContext") { + SQLContext.clearLastInstantiatedContext() + val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + assert(sqlContext != null, "SQLContext.getOrCreate returned null") + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") + } + + test("getOrCreate gets last explicitly instantiated SQLContext") { + SQLContext.clearLastInstantiatedContext() + val sqlContext = new SQLContext(ctx.sparkContext) + assert(SQLContext.getOrCreate(ctx.sparkContext) != null, + "SQLContext.getOrCreate after explicitly created SQLContext returned null") + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 479ad9fe621d..12ad019e8b47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,28 +19,59 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { +class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - import org.apache.spark.sql.test.TestSQLContext.implicits._ - val sqlCtx = TestSQLContext + val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + import sqlContext.sql + + test("having clause") { + Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") + checkAnswer( + sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), + Row("one", 6) :: Row("three", 3) :: Nil) + } + + test("SPARK-8010: promote numeric to string") { + val df = Seq((1, 1)).toDF("key", "value") + df.registerTempTable("src") + val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") + val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") + + checkAnswer(queryCaseWhen, Row("1.0") :: Nil) + checkAnswer(queryCoalesce, Row("1") :: Nil) + } + + test("SPARK-6743: no columns from cache") { + Seq( + (83, 0, 38), + (26, 0, 79), + (43, 81, 24) + ).toDF("a", "b", "c").registerTempTable("cachedData") + + sqlContext.cacheTable("cachedData") + checkAnswer( + sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + Row(0) :: Row(81) :: Nil) + } test("self join with aliases") { - Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") checkAnswer( sql( @@ -63,7 +94,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("self join with alias in agg") { - Seq(1,2,3) + Seq(1, 2, 3) .map(i => (i, i.toString)) .toDF("int", "str") .groupBy("str") @@ -81,14 +112,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -100,12 +131,38 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) + Seq(1, 1, 2, 2, 3, 3).map(Row(_)) ) } + test("SPARK-7158 collect and take return different results") { + import java.util.UUID + + val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") + // we except the id is materialized once + val idUDF = udf(() => UUID.randomUUID().toString) + + val dfWithId = df.withColumn("id", idUDF()) + // Make a new DataFrame (actually the same reference to the old one) + val cached = dfWithId.cache() + // Trigger the cache + val d0 = dfWithId.collect() + val d1 = cached.collect() + val d2 = cached.collect() + + // Since the ID is only materialized once, then all of the records + // should come from the cache, not by re-computing. Otherwise, the ID + // will be different + assert(d0.map(_(0)) === d2.map(_(0))) + assert(d0.map(_(1)) === d2.map(_(1))) + + assert(d1.map(_(0)) === d2.map(_(0))) + assert(d1.map(_(1)) === d2.map(_(1))) + } + test("grouping on nested fields") { - jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.parallelize( + """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -122,7 +179,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6201 IN type conversion") { - jsonRDD(sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + sqlContext.read.json( + sqlContext.sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -130,25 +189,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row("1"), Row("2"))) } - test("SPARK-3176 Added Parser of SQL ABS()") { - checkAnswer( - sql("SELECT ABS(-1.3)"), - Row(1.3)) - checkAnswer( - sql("SELECT ABS(0.0)"), - Row(0.0)) - checkAnswer( - sql("SELECT ABS(2.5)"), - Row(2.5)) - } - test("aggregation with codegen") { - val originalValue = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) // Prepare a table that we can group some rows. - table("testData") - .unionAll(table("testData")) - .unionAll(table("testData")) + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -170,77 +217,79 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedResults) } - // Just to group rows. - testCodeGen( - "SELECT key FROM testData3x GROUP BY key", - (1 to 100).map(Row(_))) - // COUNT - testCodeGen( - "SELECT key, count(value) FROM testData3x GROUP BY key", - (1 to 100).map(i => Row(i, 3))) - testCodeGen( - "SELECT count(key) FROM testData3x", - Row(300) :: Nil) - // COUNT DISTINCT ON int - testCodeGen( - "SELECT value, count(distinct key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 1))) - testCodeGen( - "SELECT count(distinct key) FROM testData3x", - Row(100) :: Nil) - // SUM - testCodeGen( - "SELECT value, sum(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, 3 * i))) - testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", - Row(5050 * 3, 5050 * 3.0) :: Nil) - // AVERAGE - testCodeGen( - "SELECT value, avg(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT avg(key) FROM testData3x", - Row(50.5) :: Nil) - // MAX - testCodeGen( - "SELECT value, max(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT max(key) FROM testData3x", - Row(100) :: Nil) - // MIN - testCodeGen( - "SELECT value, min(key) FROM testData3x GROUP BY value", - (1 to 100).map(i => Row(i.toString, i))) - testCodeGen( - "SELECT min(key) FROM testData3x", - Row(1) :: Nil) - // Some combinations. - testCodeGen( - """ - |SELECT - | value, - | sum(key), - | max(key), - | min(key), - | avg(key), - | count(key), - | count(distinct key) - |FROM testData3x - |GROUP BY value - """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) - testCodeGen( - "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 1, 50.5, 300, 100) :: Nil) - // Aggregate with Code generation handling all null values - testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) - - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + try { + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | min(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + } finally { + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) + } } test("Add Parser of SQL COALESCE()") { @@ -297,6 +346,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3173 Timestamp support in the parser") { + (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").registerTempTable("timestamps") + checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) @@ -340,7 +391,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq(Row(3,1), Row(3,2)) + Seq(Row(3, 1), Row(3, 2)) ) } @@ -357,16 +408,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1,3), Row(2,3), Row(3,3))) + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } test("literal in agg grouping expressions") { checkAnswer( sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) } test("aggregates with nulls") { @@ -391,19 +442,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), @@ -431,37 +482,43 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) } test("external sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "true") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) } test("SPARK-6927 sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") - setConf(SQLConf.CODEGEN_ENABLED, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + try{ + sortTest() + } finally { + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) + } } test("SPARK-6927 external sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") - setConf(SQLConf.EXTERNAL_SORT, "true") - sortTest() - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) + try { + sortTest() + } finally { + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) + } } test("limit") { @@ -494,7 +551,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + sql( + "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } @@ -538,7 +596,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq(Row(2147483645.0,1), Row(2.0,2))) + Seq(Row(2147483645.0, 1), Row(2.0, 2))) } test("count") { @@ -605,10 +663,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = y.a""".stripMargin), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil) } test("inner join, no matches") { @@ -640,7 +698,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - ignore("cartesian product join") { + test("cartesian product join") { checkAnswer( testData3.join(testData3), Row(1, null, 1, null) :: @@ -841,7 +899,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SET commands semantics using sql()") { - conf.clear() + sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -853,37 +911,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Row(s"$testKey=$testVal"), - Row(s"${testKey + testKey}=${testVal + testVal}")) + Row(testKey, testVal), + Row(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) checkAnswer( sql(s"SET $nonexistentKey"), - Row(s"$nonexistentKey=") + Row(nonexistentKey, "") ) - conf.clear() + sqlContext.conf.clear() } test("SET commands with illegal or inappropriate argument") { - conf.clear() + sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - conf.clear() + sqlContext.conf.clear() } test("apply schema") { @@ -901,7 +959,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -931,7 +989,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -956,7 +1014,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) + val df3 = sqlContext.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1001,7 +1059,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1016,7 +1074,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3371 Renaming a function expression with group by gives error") { - TestSQLContext.udf.register("len", (s: String) => s.length) + sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1197,9 +1255,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize( + val data = sqlContext.sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - jsonRDD(data).registerTempTable("records") + sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1240,35 +1298,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-4322 Grouping field with struct field as sub expression") { - jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - dropTempTable("data") + sqlContext.dropTempTable("data") - jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json( + sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - dropTempTable("data") + sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) + Seq(2, 3, 3, 4, 4, 5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)) ) } test("Supporting relational operator '<=>' in Spark SQL") { - val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil - val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil + val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") - val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil - val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil + val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1276,23 +1336,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("Multi-column COUNT(DISTINCT ...)") { - val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { - setConf(SQLConf.CASE_SENSITIVE, "false") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - setConf(SQLConf.CASE_SENSITIVE, "true") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { - jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1304,17 +1365,106 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: special cases") { - jsonRDD(sparkContext.makeRDD( - """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) - checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - jsonRDD(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("SPARK-6583 order by aggregated function") { + Seq("1" -> 3, "1" -> 4, "2" -> 7, "2" -> 8, "3" -> 5, "3" -> 6, "4" -> 1, "4" -> 2) + .toDF("a", "b").registerTempTable("orderByData") + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4") :: Row("1") :: Row("3") :: Row("2") :: Nil) + + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + 1 + """.stripMargin), + Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + } + + test("SPARK-7952: fix the equality check between boolean and numeric types") { + withTempTable("t") { + // numeric field i, boolean field j, result of i = j, result of i <=> j + Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( + (1, true, true, true), + (0, false, true, true), + (2, true, false, false), + (2, false, false, false), + (null, true, null, false), + (null, false, null, false), + (0, null, null, false), + (1, null, null, false), + (null, null, null, true) + ).toDF("i", "b", "r1", "r2").registerTempTable("t") + + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + } + } + + test("SPARK-7067: order by queries for complex ExtractValue chain") { + withTempTable("t") { + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) + } + } + + test("SPARK-8782: ORDER BY NULL") { + withTempTable("t") { + Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") + checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) + } + } + + test("SPARK-8837: use keyword in column name") { + withTempTable("t") { + val df = Seq(1 -> "a").toDF("count", "sort") + checkAnswer(df.filter("count > 0"), Row(1, "a")) + df.registerTempTable("t") + checkAnswer(sql("select count, sort from t"), Row(1, "a")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 3fa00fd9d0cc..ab6d3dd96d27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,10 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.SparkFunSuite case class ReflectData( stringField: String, @@ -74,45 +71,44 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends FunSuite { +class ScalaReflectionRelationSuite extends SparkFunSuite { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectData") + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) + Seq(data).toDF().registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === + assert(ctx.sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), - new Timestamp(12345), Seq(1,2,3))) + new Timestamp(12345), Seq(1, 2, 3))) } test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectNullData") + Seq(data).toDF().registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectOptionalData") + Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === + assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. test("query binary data") { - val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.toDF().registerTempTable("reflectBinary") + Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + val result = ctx.sql("SELECT data FROM reflectBinary") + .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -128,20 +124,19 @@ class ScalaReflectionRelationSuite extends FunSuite { Map(10 -> 100L, 20 -> 200L), Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectComplexData") - assert(sql("SELECT * FROM reflectComplexData").collect().head === - new GenericRow(Array[Any]( + Seq(data).toDF().registerTempTable("reflectComplexData") + assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + Row( Seq(1, 2, 3), Seq(1, 2, null), Map(1 -> 10L, 2 -> 20L), Map(1 -> 10L, 2 -> 20L, 3 -> null), - new GenericRow(Array[Any]( + Row( Seq(10, 20, 30), Seq(10, 20, null), Map(10 -> 100L, 20 -> 200L), Map(10 -> 100L, 20 -> 200L, 30 -> null), - new GenericRow(Array[Any](null, "abc"))))))) + Row(null, "abc")))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 6f6d3c9c243d..e55c9e460b79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.test.TestSQLContext -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(TestSQLContext.sparkContext) + val sqlContext = new SQLContext(ctx.sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 8fbc2d23d47e..207d7a352c7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ @@ -109,8 +108,8 @@ object TestData { case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: - ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) arrayData.toDF().registerTempTable("arrayData") case class MapData(data: scala.collection.Map[Int, String]) @@ -174,12 +173,6 @@ object TestData { "3, C3, true, null" :: "4, D4, true, 2147483644" :: Nil) - case class TimestampField(time: Timestamp) - val timestamps = TestSQLContext.sparkContext.parallelize((0 to 3).map { i => - TimestampField(new Timestamp(i)) - }) - timestamps.toDF().registerTempTable("timestamps") - case class IntField(i: Int) // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d615542ab50a..c1516b450cbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,43 +17,159 @@ package org.apache.spark.sql -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ -import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("built-in fixed arity expressions") { + val df = ctx.emptyDataFrame + df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") + } + + test("built-in vararg expressions") { + val df = Seq((1, 2)).toDF("a", "b") + df.selectExpr("array(a, b)") + df.selectExpr("struct(a, b)") + } + + test("built-in expressions with multiple constructors") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("substr(a, 2)", "substr(a, 2, 3)").collect() + } + + test("count") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(a)") + } + + test("count distinct") { + val df = Seq(("abcd", 2)).toDF("a", "b") + df.selectExpr("count(distinct a)") + } + + test("error reporting for incorrect number of arguments") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("substr('abcd', 2, 3, 4)") + } + assert(e.getMessage.contains("arguments")) + } + + test("error reporting for undefined functions") { + val df = ctx.emptyDataFrame + val e = intercept[AnalysisException] { + df.selectExpr("a_function_that_does_not_exist()") + } + assert(e.getMessage.contains("undefined function")) + } + test("Simple UDF") { - udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) + ctx.udf.register("strLenScala", (_: String).length) + assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) + ctx.udf.register("random0", () => { Math.random()}) + assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + } + + test("UDF in a WHERE") { + ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + + val df = ctx.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("integerData") + + val result = + ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } + + test("UDF in a HAVING") { + ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } + + test("UDF in a GROUP BY") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } + + test("UDFs everywhere") { + ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) + ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) + ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.registerTempTable("groupData") + + val result = + ctx.sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) } test("struct UDF") { - udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - sql("SELECT returnStruct('test', 'test2') as ret") + ctx.sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } + + test("type coercion for udf inputs") { + ctx.udf.register("intExpected", (x: Int) => x) + // pass a decimal to intExpected. + assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 2672e20deadc..45c9f06941c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import java.io.File - -import org.apache.spark.util.Utils - import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -28,12 +24,11 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -72,11 +67,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD = sparkContext.parallelize(points).toDF() + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val pointsRDD = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -94,10 +91,10 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - sql("SELECT testType(features) from points"), + ctx.sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } @@ -105,13 +102,13 @@ class UserDefinedTypeSuite extends QueryTest { test("UDTs with Parquet") { val tempDir = Utils.createTempDir() tempDir.delete() - pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) + pointsRDD.write.parquet(tempDir.getCanonicalPath) } test("Repartition UDTs with Parquet") { val tempDir = Utils.createTempDir() tempDir.delete() - pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) + pointsRDD.repartition(1).write.parquet(tempDir.getCanonicalPath) } // Tests to make sure that all operators correctly convert types on the way out. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 7cefcf44061c..9bd7b221e93f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -17,27 +17,30 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.types._ -class ColumnStatsSuite extends FunSuite { - testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0)) - testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) +class ColumnStatsSuite extends SparkFunSuite { + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, + InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, + InternalRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) + testColumnStats(classOf[FixedDecimalColumnStats], + FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: Row): Unit = { + initialStatistics: InternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 1e105e259dce..4d46a657056e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,27 +18,27 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp -import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.spark.serializer.KryoRegistrator -import org.scalatest.FunSuite +import com.esotericsoftware.kryo.{Kryo, Serializer} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String -class ColumnTypeSuite extends FunSuite with Logging { + +class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12, - BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, + LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -60,27 +60,24 @@ class ColumnTypeSuite extends FunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(DATE, Int.MaxValue, 4) checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(TIMESTAMP, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, new Timestamp(0L), 12) - - val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } - testNativeColumnType[BooleanType.type]( - BOOLEAN, + testNativeColumnType(BOOLEAN)( (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -88,18 +85,23 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.get() == 1 }) - testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + testNativeColumnType(BYTE)(_.put(_), _.get) + + testNativeColumnType(SHORT)(_.putShort(_), _.getShort) + + testNativeColumnType(INT)(_.putInt(_), _.getInt) + + testNativeColumnType(DATE)(_.putInt(_), _.getInt) - testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + testNativeColumnType(LONG)(_.putLong(_), _.getLong) - testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - testNativeColumnType[DecimalType]( - FIXED_DECIMAL(15, 10), + testNativeColumnType(FIXED_DECIMAL(15, 10))( (buffer: ByteBuffer, decimal: Decimal) => { buffer.putLong(decimal.toUnscaledLong) }, @@ -107,10 +109,8 @@ class ColumnTypeSuite extends FunSuite with Logging { Decimal(buffer.getLong(), 15, 10) }) - testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) - testNativeColumnType[StringType.type]( - STRING, + testNativeColumnType(STRING)( (buffer: ByteBuffer, string: UTF8String) => { val bytes = string.getBytes buffer.putInt(bytes.length) @@ -120,7 +120,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - UTF8String(bytes) + UTF8String.fromBytes(bytes) }) testColumnType[BinaryType.type, Array[Byte]]( @@ -167,7 +167,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val serializer = new SparkSqlSerializer(conf).newInstance() val buffer = ByteBuffer.allocate(512) - val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val obj = CustomClass(Int.MaxValue, Long.MaxValue) val serializedObj = serializer.serialize(obj).array() GENERIC.append(serializer.serialize(obj).array(), buffer) @@ -197,8 +197,8 @@ class ColumnTypeSuite extends FunSuite with Logging { } def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T], - putter: (ByteBuffer, T#InternalType) => Unit, + columnType: NativeColumnType[T]) + (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { testColumnType[T, T#InternalType](columnType, putter, getter) @@ -278,7 +278,7 @@ private[columnar] object CustomerSerializer extends Serializer[CustomClass] { override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { val a = input.readInt() val b = input.readLong() - CustomClass(a,b) + CustomClass(a, b) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 75d993e563e0..d9861339739c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - import scala.collection.immutable.HashSet import scala.util.Random - -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType} +import org.apache.spark.sql.types.{DataType, Decimal, AtomicType} +import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -41,21 +39,18 @@ object ColumnarTestUtils { } (columnType match { + case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() + case DATE => Random.nextInt() case LONG => Random.nextLong() + case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case STRING => UTF8String(Random.nextString(Random.nextInt(32))) - case BOOLEAN => Random.nextBoolean() + case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) - case DATE => Random.nextInt() - case TIMESTAMP => - val timestamp = new Timestamp(Random.nextLong()) - timestamp.setNanos(Random.nextInt(999999999)) - timestamp + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) @@ -81,9 +76,9 @@ object ColumnarTestUtils { def makeRandomRow( head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail) + tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail) - def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = { + def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = { val row = new GenericMutableRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 56591d9dba29..01bc23277fa8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -20,19 +20,20 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, TestData} +import org.apache.spark.sql.{QueryTest, Row, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { // Make sure the tables are loaded. TestData + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.{logicalPlanToSparkQuery, sql} + test("simple columnar query") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -40,16 +41,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - cacheTable("sizeTst") + ctx.cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - conf.autoBroadcastJoinThreshold) + ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + ctx.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -58,7 +59,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = executePlan(testData.logicalPlan).executedPlan + val plan = ctx.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -70,7 +71,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("repeatedData") + ctx.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -82,7 +83,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - cacheTable("nullableRepeatedData") + ctx.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -90,15 +91,18 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-2729 regression: timestamp data type") { + val timestamps = (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time") + timestamps.registerTempTable("timestamps") + checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) - cacheTable("timestamps") + ctx.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) } test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { @@ -106,7 +110,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - cacheTable("withEmptyParts") + ctx.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -155,7 +159,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Create a RDD for the schema val rdd = - sparkContext.parallelize((1 to 100), 10).map { i => + ctx.sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -173,20 +177,20 @@ class InMemoryColumnarQuerySuite extends QueryTest { new Timestamp(i), (1 to i).toSeq, (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, - Row((i - 0.25).toFloat, (1 to i).toSeq)) + Row((i - 0.25).toFloat, Seq(true, false, null))) } - createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - isCached("InMemoryCache_different_data_types"), + ctx.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - table("InMemoryCache_different_data_types").collect()) - dropTempTable("InMemoryCache_different_data_types") + ctx.table("InMemoryCache_different_data_types").collect()) + ctx.dropTempTable("InMemoryCache_different_data_types") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index a0702144f942..9eaa76984608 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.DataType @@ -39,13 +38,13 @@ object TestNullableColumnAccessor { } } -class NullableColumnAccessorSuite extends FunSuite { +class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 3a5605d2335d..17e9ae464bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ @@ -35,13 +34,13 @@ object TestNullableColumnBuilder { } } -class NullableColumnBuilderSuite extends FunSuite { +class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnBuilder(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2a0b701cad7f..2c0879927a12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,43 +17,46 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { - val originalColumnBatchSize = conf.columnBatchSize - val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators - setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { - setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) } before { - cacheTable("pruningData") + ctx.cacheTable("pruningData") } after { - uncacheTable("pruningData") + ctx.uncacheTable("pruningData") } // Comparisons @@ -107,7 +110,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = sql(query) + val df = ctx.sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 8b518f094174..f606e2133bed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - -import org.apache.spark.sql.Row +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} -class BooleanBitSetSuite extends FunSuite { +class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ def skeleton(count: Int) { @@ -33,7 +32,7 @@ class BooleanBitSetSuite extends FunSuite { // ------------- val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) - val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) + val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) val values = rows.map(_(0)) rows.foreach(builder.appendFrom(_, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index 64b70552eb04..acfab6586c0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.columnar.compression import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class DictionaryEncodingSuite extends FunSuite { - testDictionaryEncoding(new IntColumnStats, INT) - testDictionaryEncoding(new LongColumnStats, LONG) +class DictionaryEncodingSuite extends SparkFunSuite { + testDictionaryEncoding(new IntColumnStats, INT) + testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) def testDictionaryEncoding[T <: AtomicType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index bfd99f143bed..2111e9fbe62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType -class IntegralDeltaSuite extends FunSuite { - testIntegralDelta(new IntColumnStats, INT, IntDelta) +class IntegralDeltaSuite extends SparkFunSuite { + testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( @@ -116,7 +115,7 @@ class IntegralDeltaSuite extends FunSuite { test(s"$scheme: simple case") { val input = columnType match { - case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index fde7a4595be0..67ec08f594a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class RunLengthEncodingSuite extends FunSuite { +class RunLengthEncodingSuite extends SparkFunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) - testRunLengthEncoding(new ByteColumnStats, BYTE) - testRunLengthEncoding(new ShortColumnStats, SHORT) - testRunLengthEncoding(new IntColumnStats, INT) - testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new ByteColumnStats, BYTE) + testRunLengthEncoding(new ShortColumnStats, SHORT) + testRunLengthEncoding(new IntColumnStats, INT) + testRunLengthEncoding(new LongColumnStats, LONG) + testRunLengthEncoding(new StringColumnStats, STRING) def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 523be56df65b..3dd24130af81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,21 +17,19 @@ package org.apache.spark.sql.execution -import org.scalatest.FunSuite - -import org.apache.spark.sql.{SQLConf, execution} -import org.apache.spark.sql.functions._ +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Row, SQLConf, execution} -class PlannerSuite extends FunSuite { +class PlannerSuite extends SparkFunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head @@ -65,7 +63,7 @@ class PlannerSuite extends FunSuite { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) @@ -121,12 +119,12 @@ class PlannerSuite extends FunSuite { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") @@ -141,6 +139,12 @@ class PlannerSuite extends FunSuite { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + } + + test("efficient limit -> project -> sort") { + val query = testData.sort('key).select('value).limit(2).logicalPlan + val planned = planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala new file mode 100644 index 000000000000..a1e3ca11b1ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class SortSuite extends SparkPlanTest { + + // This test was originally added as an example of how to use [[SparkPlanTest]]; + // it's not designed to be a comprehensive test of ExternalSort. + test("basic sorting using ExternalSort") { + + val input = Seq( + ("Hello", 4, 2.0), + ("Hello", 1, 1.0), + ("World", 8, 3.0) + ) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), + input.sorted) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), + input.sortBy(t => (t._2, t._1))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala new file mode 100644 index 000000000000..108b1122f7bf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import org.apache.spark.SparkFunSuite + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.util._ + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame} + +/** + * Base class for writing tests for individual physical operators. For an example of how this + * class's test helper methods can be used, see [[SortSuite]]. + */ +class SparkPlanTest extends SparkFunSuite { + + /** + * Creates a DataFrame from a local Seq of Product. + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + TestSQLContext.implicits.localSeqToDataFrameHolder(data) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + checkAnswer(left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(input, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + left: DataFrame, + right: DataFrame, + planFunction: (SparkPlan, SparkPlan) => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(left, right, planFunction, expectedRows) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + checkAnswer(input, planFunction, expectedRows) + } + +} + +/** + * Helper methods for writing tests of individual physical operators. + */ +object SparkPlanTest { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer( + input: Seq[DataFrame], + planFunction: Seq[SparkPlan] => SparkPlan, + expectedAnswer: Seq[Row]): Option[String] = { + + val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) + + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = TestSQLContext.prepareForExecution.execute( + outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + ) + + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + converted.sortBy(_.toString()) + } + + val sparkAnswer: Seq[Row] = try { + resolvedPlan.executeCollect().toSeq + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | Results do not match for Spark plan: + | $outputPlan + | == Results == + | ${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + + None + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 15337c404543..8631e247c6c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} -import org.scalatest.{FunSuite, BeforeAndAfterAll} +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.serializer.Serializer -import org.apache.spark.ShuffleDependency +import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -class SparkSqlSerializer2DataTypeSuite extends FunSuite { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { // Make sure that we will not use serializer2 for unsupported data types. def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = @@ -74,11 +74,13 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll var numShufflePartitions: Int = _ var useSerializer2: Boolean = _ + protected lazy val ctx = TestSQLContext + override def beforeAll(): Unit = { - numShufflePartitions = conf.numShufflePartitions - useSerializer2 = conf.useSqlSerializer2 + numShufflePartitions = ctx.conf.numShufflePartitions + useSerializer2 = ctx.conf.useSqlSerializer2 - sql("set spark.sql.useSerializer2=true") + ctx.sql("set spark.sql.useSerializer2=true") val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, @@ -94,7 +96,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll // Create a RDD with all data types supported by SparkSqlSerializer2. val rdd = - sparkContext.parallelize((1 to 1000), 10).map { i => + ctx.sparkContext.parallelize((1 to 1000), 10).map { i => Row( s"str${i}: test serializer2.", s"binary${i}: test serializer2.".getBytes("UTF-8"), @@ -112,15 +114,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll new Timestamp(i)) } - createDataFrame(rdd, schema).registerTempTable("shuffle") + ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") super.beforeAll() } override def afterAll(): Unit = { - dropTempTable("shuffle") - sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - sql(s"set spark.sql.useSerializer2=$useSerializer2") + ctx.dropTempTable("shuffle") + ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") super.afterAll() } @@ -141,16 +143,16 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } test("key schema and value schema are not nulls") { - val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, - table("shuffle").collect()) + ctx.table("shuffle").collect()) } test("key schema is null") { val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = sql(s"SELECT $aggregations FROM shuffle") + val df = ctx.sql(s"SELECT $aggregations FROM shuffle") checkSerializer(df.queryExecution.executedPlan, serializerClass) checkAnswer( df, @@ -158,15 +160,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll } test("value schema is null") { - val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") + val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert( - df.map(r => r.getString(0)).collect().toSeq === - table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + assert(df.map(r => r.getString(0)).collect().toSeq === + ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) } test("no map output field") { - val df = sql(s"SELECT 1 + 1 FROM shuffle") + val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) } } @@ -177,8 +178,8 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { super.beforeAll() // Sort merge will not be triggered. val bypassMergeThreshold = - sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") + ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") } } @@ -189,7 +190,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite super.beforeAll() // To trigger the sort merge. val bypassMergeThreshold = - sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") + ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 358d8cf06e46..8ec3985e0036 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.debug -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ -class DebuggingSuite extends FunSuite { +class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 2aad01ded1ac..71db6a215985 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,47 +17,46 @@ package org.apache.spark.sql.execution.joins -import org.scalatest.FunSuite - -import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Projection, InternalRow} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends FunSuite { +class HashedRelationSuite extends SparkFunSuite { // Key is simply the record itself private val keyProjection = new Projection { - override def apply(row: Row): Row = row + override def apply(row: InternalRow): InternalRow = row } test("GeneralHashedRelation") { - val data = Array(Row(0), Row(1), Row(2), Row(2)) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) - assert(hashed.get(Row(10)) === null) + assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(InternalRow(10)) === null) - val data2 = CompactBuffer[Row](data(2)) + val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) == data2) } test("UniqueKeyHashedRelation") { - val data = Array(Row(0), Row(1), Row(2)) + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) - assert(hashed.get(Row(10)) === null) + assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] assert(uniqHashed.getValue(data(0)) == data(0)) assert(uniqHashed.getValue(data(1)) == data(1)) assert(uniqHashed.getValue(data(2)) == data(2)) - assert(uniqHashed.getValue(Row(10)) == null) + assert(uniqHashed.getValue(InternalRow(10)) == null) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala new file mode 100644 index 000000000000..5707d2fb300a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest { + + val left = Seq( + (1, 2.0), + (2, 1.0), + (3, 3.0) + ).toDF("a", "b") + + val right = Seq( + (2, 3.0), + (3, 2.0), + (4, 1.0) + ).toDF("c", "d") + + val leftKeys: List[Expression] = 'a :: Nil + val rightKeys: List[Expression] = 'c :: Nil + val condition = Some(LessThan('b, 'd)) + + test("shuffled hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } + + test("broadcast hash outer join") { + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), + Seq( + (1, 2.0, null, null), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null) + )) + + checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), + Seq( + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0) + )) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 2abfe7f167f7..69ab1c292d22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,20 +21,30 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} -import org.apache.spark.sql.test._ -import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException -import org.scalatest.{FunSuite, BeforeAndAfter} -import TestSQLContext._ -import TestSQLContext.implicits._ +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ -class JDBCSuite extends FunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) + val testH2Dialect = new JdbcDialect { + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + Some(StringType) + } + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test @@ -61,6 +71,14 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + sql( + s""" + |CREATE TEMPORARY TABLE fetchtwo + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | fetchSize '2') + """.stripMargin.replaceAll("\n", " ")) + sql( s""" |CREATE TEMPORARY TABLE parts @@ -178,6 +196,14 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(names(2).equals("mary")) } + test("SELECT first field when fetchSize is two") { + val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) + assert(names.size === 3) + assert(names(0).equals("fred")) + assert(names(1).equals("joe 'foo' \"bar\"")) + assert(names(2).equals("mary")) + } + test("SELECT second field") { val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) @@ -186,6 +212,14 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(ids(2) === 3) } + test("SELECT second field when fetchSize is two") { + val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) + } + test("SELECT * partitioned") { assert(sql("SELECT * FROM parts").collect().size == 3) } @@ -221,22 +255,32 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3) + assert(ctx.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) + } + + test("Basic API with FetchSize") { + val properties = new Properties + properties.setProperty("fetchSize", "2") + assert(ctx.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) - .collect.size === 3) + assert( + ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3) + assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + .collect().length === 3) } test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows(0).getInt(0) === 1) assert(rows(0).getBoolean(1) === false) assert(rows(0).getInt(2) === 3) @@ -246,7 +290,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("H2 null entries") { val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) assert(rows(0).isNullAt(2)) @@ -282,28 +326,31 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(cal.get(Calendar.HOUR) === 11) assert(cal.get(Calendar.MINUTE) === 22) assert(cal.get(Calendar.SECOND) === 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500) } test("test DATE types") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() - val cachedRows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().collect() + val rows = ctx.read.jdbc( + urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } test("test DATE types in cache") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() - TestSQLContext - .jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().registerTempTable("mycached_date") + val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } test("test types for null value") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.NULLTYPES").collect() + val rows = ctx.read.jdbc( + urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -346,4 +393,55 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { """.stripMargin.replaceAll("\n", " ")) } } + + test("Remap types via JdbcDialects") { + JdbcDialects.registerDialect(testH2Dialect) + val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) + val rows = df.collect() + assert(rows(0).get(0).isInstanceOf[String]) + assert(rows(0).get(1).isInstanceOf[String]) + JdbcDialects.unregisterDialect(testH2Dialect) + } + + test("Default jdbc dialect registration") { + assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) + assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("test.invalid") == NoopDialect) + } + + test("quote column names by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + + val columns = Seq("abc", "key") + val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) + val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) + assert(MySQLColumns === Seq("`abc`", "`key`")) + assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + } + + test("Dialect unregister") { + JdbcDialects.registerDialect(testH2Dialect) + JdbcDialects.unregisterDialect(testH2Dialect) + assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect) + } + + test("Aggregated dialects") { + val agg = new AggregatedDialect(List(new JdbcDialect { + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + if (sqlType % 2 == 0) { + Some(LongType) + } else { + None + } + }, testH2Dialect)) + assert(agg.canHandle("jdbc:h2:xxx")) + assert(!agg.canHandle("jdbc:h2")) + assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 0800eded443d..d949ef42267e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.Row -import org.apache.spark.sql.test._ +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{SaveMode, Row} import org.apache.spark.sql.types._ -class JDBCWriteSuite extends FunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -35,12 +35,16 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { properties.setProperty("user", "testUser") properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() - + conn1 = DriverManager.getConnection(url1, properties) conn1.prepareStatement("create schema test").executeUpdate() conn1.prepareStatement("drop table if exists test.people").executeUpdate() @@ -52,20 +56,20 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { conn1.prepareStatement( "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - - TestSQLContext.sql( + + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - - TestSQLContext.sql( + + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll("\n", " ")) } after { @@ -73,81 +77,81 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { conn1.close() } - val sc = TestSQLContext.sparkContext + private lazy val sc = ctx.sparkContext - val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - val arr1x2 = Array[Row](Row.apply("fred", 3)) - val schema2 = StructType( + private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) + private lazy val schema2 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: Nil) - val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) - val schema3 = StructType( + private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + private lazy val schema3 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - df.createJDBCTable(url, "TEST.BASICCREATETEST", false) - assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").collect()(0).length) + df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url1, "TEST.DROPTEST", false, properties) - assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df.write.jdbc(url1, "TEST.DROPTEST", properties) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) - df2.createJDBCTable(url1, "TEST.DROPTEST", true, properties) - assert(1 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) + assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url, "TEST.APPENDTEST", false) - df2.insertIntoJDBC(url, "TEST.APPENDTEST", false) - assert(3 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").collect()(0).length) + df.write.jdbc(url, "TEST.APPENDTEST", new Properties) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) + assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url1, "TEST.TRUNCATETEST", false, properties) - df2.insertIntoJDBC(url1, "TEST.TRUNCATETEST", true, properties) - assert(1 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) + assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - df.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) + df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { - df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) } } - + test("INSERT to JDBC Datasource") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } - + test("INSERT to JDBC Datasource with overwrite") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) - } + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index b06e3385980f..8204a584179b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,22 +23,19 @@ import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.util.Utils -class JsonSuite extends QueryTest { - import org.apache.spark.sql.json.TestJsonData._ +class JsonSuite extends QueryTest with TestJsonData { - TestJsonData + protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.sql + import ctx.implicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -79,22 +76,28 @@ class JsonSuite extends QueryTest { checkTypePromotion( Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) - checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), + enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" - checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" checkTypePromotion( - DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) + DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), + enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(DateTimeUtils.millisToDays(3601000), + enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" - checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), + enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(DateTimeUtils.millisToDays(10801000), + enforceCorrectType(ISO8601Time2, DateType)) } test("Get compatible type") { @@ -215,7 +218,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonDF = jsonRDD(jsonNullStruct) + val jsonDF = ctx.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -234,7 +237,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonDF = jsonRDD(primitiveFieldAndType) + val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -262,7 +265,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonDF = jsonRDD(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -361,7 +364,7 @@ class JsonSuite extends QueryTest { } test("GetField operation on complex data type") { - val jsonDF = jsonRDD(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -377,7 +380,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -451,7 +454,7 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -504,7 +507,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonDF = jsonRDD(complexFieldValueTypeConflict) + val jsonDF = ctx.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -523,12 +526,12 @@ class JsonSuite extends QueryTest { Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: Row(null, """{"field":false}""", null, null, "{}") :: Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: - Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil + Row(Seq(7), "{}", """["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil ) } test("Type conflict in array elements") { - val jsonDF = jsonRDD(arrayElementTypeConflict) + val jsonDF = ctx.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -556,7 +559,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonDF = jsonRDD(missingFields) + val jsonDF = ctx.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -575,8 +578,9 @@ class JsonSuite extends QueryTest { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = jsonFile(path, 0.49) + ctx.sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -591,7 +595,7 @@ class JsonSuite extends QueryTest { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation] + ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) @@ -603,7 +607,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = jsonFile(path) + val jsonDF = ctx.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -672,7 +676,7 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = jsonFile(path, schema) + val jsonDF1 = ctx.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -689,7 +693,7 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonDF2 = jsonRDD(primitiveFieldAndType, schema) + val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -710,7 +714,7 @@ class JsonSuite extends QueryTest { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = jsonRDD(mapType1, schemaWithSimpleMap) + val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -738,7 +742,7 @@ class JsonSuite extends QueryTest { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = jsonRDD(mapType2, schemaWithComplexMap) + val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -764,7 +768,7 @@ class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = jsonRDD(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -782,7 +786,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonDF = jsonRDD(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -805,7 +809,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = jsonRDD(jsonArray) + val jsonDF = ctx.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -823,10 +827,10 @@ class JsonSuite extends QueryTest { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = jsonRDD(corruptRecords) + val jsonDF = ctx.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -876,11 +880,11 @@ class JsonSuite extends QueryTest { Row("]") :: Nil ) - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } test("SPARK-4068: nulls in arrays") { - val jsonDF = jsonRDD(nullsInArrays) + val jsonDF = ctx.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -926,7 +930,7 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = createDataFrame(rowRDD1, schema1) + val df1 = ctx.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -949,7 +953,7 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = createDataFrame(rowRDD2, schema2) + val df3 = ctx.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -957,8 +961,8 @@ class JsonSuite extends QueryTest { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = jsonRDD(primitiveFieldAndType) - val primTable = jsonRDD(jsonDF.toJSON) + val jsonDF = ctx.read.json(primitiveFieldAndType) + val primTable = ctx.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -970,8 +974,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonDF = jsonRDD(complexFieldAndType1) - val compTable = jsonRDD(complexJsonDF.toJSON) + val complexJsonDF = ctx.read.json(complexFieldAndType1) + val compTable = ctx.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1074,4 +1078,35 @@ class JsonSuite extends QueryTest { assert(StructType(Seq()) === emptySchema) } + test("SPARK-7565 MapType in JsonRDD") { + val useStreaming = ctx.conf.useJacksonStreamingAPI + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + try{ + for (useStreaming <- List(true, false)) { + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + val temp = Utils.createTempDir().getPath + + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + df.write.mode("overwrite").parquet(temp) + // order of MapType is not defined + assert(ctx.read.parquet(temp).count() == 5) + + val df2 = ctx.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(temp) + checkAnswer(ctx.read.parquet(temp), df2.collect()) + } + } finally { + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + } + } + + test("SPARK-8093 Erase empty structs") { + val emptySchema = InferSchema(emptyRecords, 1.0, "") + assert(StructType(Seq()) === emptySchema) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 47a97a49daab..eb62066ac643 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext -object TestJsonData { +trait TestJsonData { - val primitiveFieldAndType = - TestSQLContext.sparkContext.parallelize( + protected def ctx: SQLContext + + def primitiveFieldAndType: RDD[String] = + ctx.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -32,8 +35,8 @@ object TestJsonData { "null":null }""" :: Nil) - val primitiveFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def primitiveFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -43,15 +46,15 @@ object TestJsonData { """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) - val jsonNullStruct = - TestSQLContext.sparkContext.parallelize( + def jsonNullStruct: RDD[String] = + ctx.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) - val complexFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def complexFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -61,23 +64,23 @@ object TestJsonData { """{"num_struct":{}, "str_array":["str1", "str2", 33], "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) - val arrayElementTypeConflict = - TestSQLContext.sparkContext.parallelize( + def arrayElementTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) - val missingFields = - TestSQLContext.sparkContext.parallelize( + def missingFields: RDD[String] = + ctx.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) - val complexFieldAndType1 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType1: RDD[String] = + ctx.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -92,8 +95,8 @@ object TestJsonData { "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) - val complexFieldAndType2 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType2: RDD[String] = + ctx.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -146,16 +149,16 @@ object TestJsonData { ]] }""" :: Nil) - val mapType1 = - TestSQLContext.sparkContext.parallelize( + def mapType1: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: """{"map": {"c": 1, "d": 4}}""" :: """{"map": {"e": null}}""" :: Nil) - val mapType2 = - TestSQLContext.sparkContext.parallelize( + def mapType2: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -163,22 +166,22 @@ object TestJsonData { """{"map": {"e": null}}""" :: """{"map": {"f": {"field1": null}}}""" :: Nil) - val nullsInArrays = - TestSQLContext.sparkContext.parallelize( + def nullsInArrays: RDD[String] = + ctx.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) - val jsonArray = - TestSQLContext.sparkContext.parallelize( + def jsonArray: RDD[String] = + ctx.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) - val corruptRecords = - TestSQLContext.sparkContext.parallelize( + def corruptRecords: RDD[String] = + ctx.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -186,6 +189,14 @@ object TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) - val empty = - TestSQLContext.sparkContext.parallelize(Seq[String]()) + def emptyRecords: RDD[String] = + ctx.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a": {}}""" :: + """{"a": {"b": {}}}""" :: + """{"b": [{"c": {}}]}""" :: + """]""" :: Nil) + + def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 5ad439584716..a2763c78b645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -18,16 +18,15 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll -import parquet.filter2.predicate.Operators._ -import parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.Operators._ +import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -42,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * data type is nullable. */ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext private def checkFilterPredicate( df: DataFrame, @@ -52,7 +51,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -312,28 +311,28 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("SPARK-6554: don't push down predicates which reference partition columns") { import sqlContext.implicits._ - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.parquetFile(path).filter("part = 1"), + sqlContext.read.parquet(path).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -341,23 +340,23 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("SPARK-6742: don't push down predicates which reference partition columns") { import sqlContext.implicits._ - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 008443df216a..7b16eba00d6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -23,24 +23,22 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.scalatest.BeforeAndAfterAll -import parquet.example.data.simple.SimpleGroup -import parquet.example.data.{Group, GroupWriter} -import parquet.hadoop.api.WriteSupport -import parquet.hadoop.api.WriteSupport.WriteContext -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData, CompressionCodecName} -import parquet.hadoop.{Footer, ParquetFileWriter, ParquetWriter} -import parquet.io.api.RecordConsumer -import parquet.schema.{MessageType, MessageTypeParser} +import org.apache.spark.SparkException +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -66,9 +64,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS * A test suite that tests basic Parquet I/O. */ class ParquetIOSuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext - - import sqlContext.implicits.localSeqToDataFrameHolder + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. @@ -97,14 +94,13 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 4).map(i => Tuple1(i.toString)) // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL // as we store Spark SQL schema in the extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data)) } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext + sqlContext.sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) .toDF() @@ -114,40 +110,40 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } // Decimals with precision above 18 are not yet supported intercept[Throwable] { withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).collect() } } // Unlimited-length decimals are not yet supported intercept[Throwable] { withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).collect() } } } test("date type") { def makeDateRDD(): DataFrame = - sparkContext + sqlContext.sparkContext .parallelize(0 to 1000) - .map(i => Tuple1(DateUtils.toJavaDate(i))) + .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) .toDF() .select($"_1") withTempPath { dir => val data = makeDateRDD() - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -161,6 +157,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { checkParquetFile(data) } + test("array and double") { + val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) + checkParquetFile(data) + } + test("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => @@ -200,7 +201,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNulls :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows.head === Row(Seq.fill(5)(null): _*)) } } @@ -213,7 +214,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetDataFrame(allNones :: Nil) { df => val rows = df.collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows.head === Row(Seq.fill(3)(null): _*)) } } @@ -234,9 +235,9 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (0 until 10).map(i => (i, i.toString)) def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(conf.parquetCompressionCodec.toUpperCase) { + assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { compressionCodecFor(path) } } @@ -244,7 +245,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) + checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) @@ -283,7 +284,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => + checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i => Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } @@ -311,8 +312,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("save - overwrite") { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file)) - checkAnswer(parquetFile(file), newData.map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) + checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple)) } } @@ -320,8 +321,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 10).map(i => (i, i.toString)) withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file)) - checkAnswer(parquetFile(file), data.map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) + checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple)) } } @@ -330,8 +331,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) val errorMessage = intercept[Throwable] { - newData.toDF().save( - "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file)) + newData.toDF().write.format("parquet").mode(SaveMode.ErrorIfExists).save(file) }.getMessage assert(errorMessage.contains("already exists")) } @@ -341,8 +341,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 10).map(i => (i, i.toString)) withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file)) - checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) + checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple)) } } @@ -370,11 +370,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sparkContext.hadoopConfiguration, + sqlContext.sparkContext.hadoopConfiguration, path, new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) - assertResult(parquetFile(path.toString).schema) { + assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( StructField("a", BooleanType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: @@ -384,6 +384,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("SPARK-6352 DirectParquetOutputCommitter") { + val clonedConf = new Configuration(configuration) + // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { @@ -392,52 +394,84 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(configuration) assert(!fs.exists(path)) } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } - finally { - configuration.set("spark.sql.parquet.output.committer.class", - "parquet.hadoop.ParquetOutputCommitter") + } + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + withTempPath { dir => + val clonedConf = new Configuration(configuration) + + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) + + configuration.set( + "spark.sql.parquet.output.committer.class", + classOf[BogusParquetOutputCommitter].getCanonicalName) + + try { + val message = intercept[SparkException] { + sqlContext.range(0, 1).write.parquet(dir.getCanonicalPath) + }.getCause.getMessage + assert(message === "Intentional exception for testing purposes") + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } } } } +class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") + } +} + class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key, originalConf.toString) } test("SPARK-6330 regression test") { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// intercept[Throwable] { - sqlContext.parquetFile("file:///nonexistent") + sqlContext.read.parquet("file:///nonexistent") } val errorMessage = intercept[Throwable] { - sqlContext.parquetFile("hdfs://nonexistent") + sqlContext.read.parquet("hdfs://nonexistent") }.toString assert(errorMessage.contains("UnknownHostException")) } } class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 138e19766dc8..d0ebb11b063f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -14,18 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.parquet +import java.io.File +import java.math.BigInteger +import java.sql.Timestamp + import scala.collection.mutable.ArrayBuffer +import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.sources.PartitioningUtils._ -import org.apache.spark.sql.sources.{Partition, PartitionSpec} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql._ +import org.apache.spark.unsafe.types.UTF8String // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -34,64 +41,65 @@ case class ParquetData(intField: Int, stringField: String) case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - override val sqlContext: SQLContext = TestSQLContext - import sqlContext._ + override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ + import sqlContext.sql val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" test("column type inference") { def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + assert(inferPartitionColumnValue(raw, defaultPartitionName, true) === literal) } check("10", Literal.create(10, IntegerType)) check("1000000000000000", Literal.create(1000000000000000L, LongType)) - check("1.5", Literal.create(1.5, FloatType)) + check("1.5", Literal.create(1.5, DoubleType)) check("hello", Literal.create("hello", StringType)) check(defaultPartitionName, Literal.create(null, NullType)) } test("parse partition") { - def check(path: String, expected: PartitionValues): Unit = { - assert(expected === parsePartition(new Path(path), defaultPartitionName)) + def check(path: String, expected: Option[PartitionValues]): Unit = { + assert(expected === parsePartition(new Path(path), defaultPartitionName, true)) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName) + parsePartition(new Path(path), defaultPartitionName, true).get }.getMessage assert(message.contains(expected)) } - check( - "file:///", - PartitionValues( - ArrayBuffer.empty[String], - ArrayBuffer.empty[Literal])) - - check( - "file://path/a=10", + check("file://path/a=10", Some { PartitionValues( ArrayBuffer("a"), - ArrayBuffer(Literal.create(10, IntegerType)))) + ArrayBuffer(Literal.create(10, IntegerType))) + }) - check( - "file://path/a=10/b=hello/c=1.5", + check("file://path/a=10/b=hello/c=1.5", Some { PartitionValues( ArrayBuffer("a", "b", "c"), ArrayBuffer( Literal.create(10, IntegerType), Literal.create("hello", StringType), - Literal.create(1.5, FloatType)))) + Literal.create(1.5, DoubleType))) + }) - check( - "file://path/a=10/b_hello/c=1.5", + check("file://path/a=10/b_hello/c=1.5", Some { PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal.create(1.5, FloatType)))) + ArrayBuffer(Literal.create(1.5, DoubleType))) + }) + + check("file:///", None) + check("file:///path/_temporary", None) + check("file:///path/_temporary/c=1.5", None) + check("file:///path/_temporary/path", None) + check("file://path/a=10/_temporary/c=1.5", None) + check("file://path/a=10/c=1.5/_temporary", None) checkThrows[AssertionError]("file://path/=10", "Empty partition column name") checkThrows[AssertionError]("file://path/a=", "Empty partition column value") @@ -99,7 +107,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { test("parse partitions") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) } check(Seq( @@ -108,18 +116,42 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructType(Seq( StructField("a", IntegerType), StructField("b", StringType))), - Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + Seq(Partition(InternalRow(10, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10/b=hello")))) check(Seq( "hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/a=10.5/b=hello"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(10, UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(10.5, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", DoubleType), StructField("b", StringType))), Seq( - Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), - Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + Partition(InternalRow(10, UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(10.5, UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=20", @@ -129,19 +161,107 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { StructField("a", IntegerType), StructField("b", StringType))), Seq( - Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), - Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + Partition(InternalRow(10, UTF8String.fromString("20")), + s"hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(null, UTF8String.fromString("hello")), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) check(Seq( s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), PartitionSpec( StructType(Seq( - StructField("a", FloatType), + StructField("a", DoubleType), StructField("b", StringType))), Seq( - Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), - Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + Partition(InternalRow(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(InternalRow(10.5, null), + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) + } + + test("parse partitions with type inference disabled") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq(Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + "hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(UTF8String.fromString("10.5"), UTF8String.fromString("hello")), + "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), UTF8String.fromString("20")), + s"hdfs://host:9000/path/a=10/b=20"), + Partition(InternalRow(null, UTF8String.fromString("hello")), + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType))), + Seq( + Partition(InternalRow(UTF8String.fromString("10"), null), + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(InternalRow(UTF8String.fromString("10.5"), null), + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) } test("read partitioned table - normal case") { @@ -150,12 +270,17 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { pi <- Seq(1, 2) ps <- Seq("foo", "bar") } { + val dir = makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps) makeParquetFile( (1 to 10).map(i => ParquetData(i, i.toString)), - makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + dir) + // Introduce _temporary dir to test the robustness of the schema discovery process. + new File(dir.toString, "_temporary").mkdir() } + // Introduce _temporary dir to the base dir the robustness of the schema discovery process. + new File(base.getCanonicalPath, "_temporary").mkdir() - parquetFile(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -202,7 +327,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - parquetFile(base.getCanonicalPath).registerTempTable("t") + sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -250,10 +375,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = load( - "org.apache.spark.sql.parquet", - Map("path" -> base.getCanonicalPath)) - + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -293,10 +415,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = load( - "org.apache.spark.sql.parquet", - Map("path" -> base.getCanonicalPath)) - + val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -328,7 +447,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - load(base.getCanonicalPath, "org.apache.spark.sql.parquet").registerTempTable("t") + sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -337,4 +456,131 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } } } + + test("SPARK-7749 Non-partitioned table should have empty partition spec") { + withTempPath { dir => + (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) + val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution + queryExecution.analyzed.collectFirst { + case LogicalRelation(relation: ParquetRelation2) => + assert(relation.partitionSpec === PartitionSpec.emptySpec) + }.getOrElse { + fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") + } + } + } + + test("SPARK-7847: Dynamic partition directory path escaping and unescaping") { + withTempPath { dir => + val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") + df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) + checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) + } + } + + test("Various partition value types") { + val row = + Row( + 100.toByte, + 40000.toShort, + Int.MaxValue, + Long.MaxValue, + 1.5.toFloat, + 4.5, + new java.math.BigDecimal(new BigInteger("212500"), 5), + new java.math.BigDecimal(2.125), + java.sql.Date.valueOf("2015-05-23"), + new Timestamp(0), + "This is a string, /[]?=:", + "This is not a partition column") + + // BooleanType is not supported yet + val partitionColumnTypes = + Seq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(10, 5), + DecimalType.Unlimited, + DateType, + TimestampType, + StringType) + + val partitionColumns = partitionColumnTypes.zipWithIndex.map { + case (t, index) => StructField(s"p_$index", t) + } + + val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + + withTempPath { dir => + df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) + } + } + + test("SPARK-8037: Ignores files whose name starts with dot") { + withTempPath { dir => + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(dir.getCanonicalPath) + + Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) + } + } + + test("listConflictingPartitionColumns") { + def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { + val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => + s"\tPartition column name list #$index: $list" + }.mkString("\n", "\n", "\n") + + // scalastyle:off + s"""Conflicting partition column names detected: + |$conflictingColNameLists + |For partitioned table directories, data files should only live in leaf directories. + |And directories at the same level should have the same partition column name. + |Please check the following directories for unexpected files or inconsistent partition column names: + |${paths.map("\t" + _).mkString("\n", "\n", "")} + """.stripMargin.trim + // scalastyle:on + } + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim === + makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim === + makeExpectedMessage( + Seq("a"), + Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), + PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1/b=foo"), + PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim === + makeExpectedMessage( + Seq("a", "a, b"), + Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b98ba09ccfc2..a0a81c4309c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} /** * A test suite that tests various Parquet queries. */ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.sql test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -39,22 +39,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(table("t"), data.map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -111,28 +111,59 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { List(Row("same", "run_5", 100))) } } + + test("SPARK-6917 DecimalType should work with non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + + test("Enabling/disabling schema merging") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + testSchemaMerging(3) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { + testSchemaMerging(2) + } + } } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index c964b6d98455..35d3c33f99a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,31 +20,113 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.FunSuite -import parquet.schema.MessageTypeParser +import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends FunSuite with ParquetTest { +abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { val sqlContext = TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. */ - private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = { - test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes( - ScalaReflection.attributesFor[T], isThriftDerived) - val expected = MessageTypeParser.parseMessageType(messageType) + protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( + testName: String, + messageType: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + testSchema( + testName, + StructType.fromAttributes(ScalaReflection.attributesFor[T]), + messageType, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } + + protected def testParquetToCatalyst( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql <= parquet: $testName") { + val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) + val expected = sqlSchema + assert( + actual === expected, + s"""Schema mismatch. + |Expected schema: ${expected.json} + |Actual schema: ${actual.json} + """.stripMargin) + } + } + + protected def testCatalystToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) actual.checkContains(expected) expected.checkContains(actual) } } - testSchema[(Boolean, Int, Long, Float, Double, Array[Byte])]( + protected def testSchema( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + + testCatalystToParquet( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + + testParquetToCatalyst( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } +} + +class ParquetSchemaInferenceSuite extends ParquetSchemaTest { + testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])]( "basic types", """ |message root { @@ -55,9 +137,10 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { | required double _5; | optional binary _6; |} - """.stripMargin) + """.stripMargin, + binaryAsString = false) - testSchema[(Byte, Short, Int, Long, java.sql.Date)]( + testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -69,27 +152,87 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) - // Currently String is the only supported logical binary type. - testSchema[Tuple1[String]]( - "binary logical types", + testSchemaInference[Tuple1[String]]( + "string", """ |message root { | optional binary _1 (UTF8); |} + """.stripMargin, + binaryAsString = true) + + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} """.stripMargin) - testSchema[Tuple1[Seq[Int]]]( - "array", + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - non-standard", """ |message root { | optional group _1 (LIST) { - | repeated int32 array; + | repeated int32 element; | } |} """.stripMargin) - testSchema[Tuple1[Map[Int, String]]]( - "map", + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group bag { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - non-standard", """ |message root { | optional group _1 (MAP) { @@ -101,7 +244,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) - testSchema[Tuple1[Pair[Int, String]]]( + testSchemaInference[Tuple1[Pair[Int, String]]]( "struct", """ |message root { @@ -110,20 +253,21 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { | optional binary _2 (UTF8); | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - testSchema[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( - "deeply nested type", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - non-standard", """ |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP_KEY_VALUE) { + | repeated group map { | required int32 key; | optional group value { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array { + | optional group element { | required int32 _1; | required double _2; | } @@ -135,43 +279,76 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Option[Int], Map[Int, Option[Double]])]( - "optional types", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - standard", """ |message root { - | optional int32 _1; - | optional group _2 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP) { + | repeated group key_value { | required int32 key; - | optional double value; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group list { + | optional group element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } | } | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", + testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( + "optional types", """ |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } + | optional int32 _1; + | optional group _2 (MAP) { + | repeated group key_value { + | required int32 key; + | optional double value; | } | } |} - """.stripMargin, isThriftDerived = true) + """.stripMargin, + followParquetFormatSpec = true) + + // Parquet files generated by parquet-thrift are already handled by the schema converter, but + // let's leave this test here until both read path and write path are all updated. + ignore("thrift generated parquet schema") { + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchemaInference[( + Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, + isThriftDerived = true) + } +} +class ParquetSchemaSuite extends ParquetSchemaTest { test("DataType string parser compatibility") { // This is the generated string from previous versions of the Spark SQL, using the following: // val schema = StructType(List( @@ -181,10 +358,7 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" // scalastyle:off - val jsonString = - """ - |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} - """.stripMargin + val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) @@ -278,4 +452,465 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("secondField", StringType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) } + + // ======================================================= + // Tests for converting Parquet LIST to Catalyst ArrayType + // ======================================================= + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 2", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | optional int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 2", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 3", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 4", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false), + StructField("num", IntegerType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required binary str (UTF8); + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group array { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group f1_tuple { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + // ======================================================= + // Tests for converting Catalyst ArrayType to Parquet LIST + // ======================================================= + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group bag { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Parquet Map to Catalyst MapType + // ==================================================== + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | optional binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Catalyst MapType to Parquet Map + // ==================================================== + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ================================= + // Tests for conversion for decimals + // ================================= + + testSchema( + "DECIMAL(1, 0) - standard", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional int32 f1 (DECIMAL(1, 0)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(8, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional int32 f1 (DECIMAL(8, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(9, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional int32 f1 (DECIMAL(9, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(18, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional int64 f1 (DECIMAL(18, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(19, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(19, 3)))), + """message root { + | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(1, 0) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(8, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(9, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(18, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); + |} + """.stripMargin) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala similarity index 57% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 9d17516e0ef7..eb15a1609f1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -21,10 +21,9 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import scala.util.Try -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,54 +32,7 @@ import org.apache.spark.util.Utils * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest { - val sqlContext: SQLContext - - import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} - import sqlContext.{conf, sparkContext} - - protected def configuration = sparkContext.hadoopConfiguration - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) - (keys, values).zipped.foreach(conf.setConf) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConf(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } - - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } +private[sql] trait ParquetTest extends SQLTestUtils { /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -90,7 +42,7 @@ private[sql] trait ParquetTest { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath) + sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -102,14 +54,7 @@ private[sql] trait ParquetTest { protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.parquetFile(path))) - } - - /** - * Drops temporary table `tableName` after calling `f`. - */ - protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.dropTempTable(tableName) + withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } /** @@ -128,12 +73,12 @@ private[sql] trait ParquetTest { protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + df.write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makePartitionDir( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 4e54b2eb8df7..a71088430bfd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,18 +26,20 @@ import org.apache.spark.util.Utils class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var path: File = null override def beforeAll(): Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - dropTempTable("jt") + caseInsensitiveContext.dropTempTable("jt") } after { @@ -59,7 +61,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a, b FROM jsonTable"), sql("SELECT a, b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { @@ -129,7 +131,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT * FROM jsonTable"), sql("SELECT a * 4 FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") // Explicitly delete the data. if (path.exists()) Utils.deleteRecursively(path) @@ -147,7 +149,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT * FROM jsonTable"), sql("SELECT b FROM jt").collect()) - dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jsonTable") } test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 6664e8d64c13..54e1efb6e36e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class DDLScanSource extends RelationProvider { override def createRelation( @@ -43,7 +45,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), StructField("decimalType", DecimalType.Unlimited, nullable = false), - StructField("fixedDecimalType", DecimalType(5,1), nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), StructField("smallIntType", ShortType, nullable = false), @@ -51,32 +53,33 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("mapType", MapType(StringType, StringType)), StructField("arrayType", ArrayType(StringType)), StructField("structType", - StructType(StructField("f1",StringType) :: - (StructField("f2",IntegerType)) :: Nil + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil ) ) )) + override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { - sqlContext.sparkContext.parallelize(from to to).map(e => Row(s"people$e", e * 2)) + sqlContext.sparkContext.parallelize(from to to).map { e => + InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row + } } } class DDLTestSuite extends DataSourceTest { - import caseInsensitiveContext._ before { - sql( - """ - |CREATE TEMPORARY TABLE ddlPeople - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) + caseInsensitiveContext.sql( + """ + |CREATE TEMPORARY TABLE ddlPeople + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) } sqlTest( @@ -99,4 +102,11 @@ class DDLTestSuite extends DataSourceTest { Row("arrayType", "array", ""), Row("structType", "struct", "") )) + + test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { + val attributes = caseInsensitiveContext.sql("describe ddlPeople") + .queryExecution.executedPlan.output + assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) + assert(attributes.map(_.dataType).toSet === Set(StringType)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 24ed665c67d2..00cc7d5ea580 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,14 +17,18 @@ package org.apache.spark.sql.sources +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfter + abstract class DataSourceTest extends QueryTest with BeforeAndAfter { // We want to test some edge cases. - implicit val caseInsensitiveContext = new SQLContext(TestSQLContext.sparkContext) + protected implicit lazy val caseInsensitiveContext = { + val ctx = new SQLContext(TestSQLContext.sparkContext) + ctx.setConf(SQLConf.CASE_SENSITIVE, false) + ctx + } - caseInsensitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index cce747e7dbf6..81b3a0f0c5b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -97,7 +97,7 @@ object FiltersPushed { class FilteredScanSuite extends DataSourceTest { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql before { sql( @@ -154,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2))) + Seq(1, 3, 5).map(i => Row(i, i * 2))) sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE A = 1", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index d1d427e1790b..0b7c46c482c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -26,14 +26,16 @@ import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var path: File = null override def beforeAll: Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + caseInsensitiveContext.read.json(rdd).registerTempTable("jt") sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) @@ -45,8 +47,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll: Unit = { - dropTempTable("jsonTable") - dropTempTable("jt") + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") Utils.deleteRecursively(path) } @@ -109,7 +111,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) - jsonRDD(rdd1).registerTempTable("jt1") + caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -121,7 +123,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) - jsonRDD(rdd2).registerTempTable("jt2") + caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -140,8 +142,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { (1 to 10).map(i => Row(i * 10, s"str$i")) ) - dropTempTable("jt1") - dropTempTable("jt2") + caseInsensitiveContext.dropTempTable("jt1") + caseInsensitiveContext.dropTempTable("jt2") } test("INSERT INTO not supported for JSONRelation for now") { @@ -154,13 +156,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("save directly to the path of a JSON table") { - table("jt").selectExpr("a * 5 as a", "b").save(path.toString, "json", SaveMode.Overwrite) + caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") + .write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 5, s"str$i")) ) - table("jt").save(path.toString, "json", SaveMode.Overwrite) + caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) @@ -181,7 +184,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { test("Caching") { // Cached Query Execution - cacheTable("jsonTable") + caseInsensitiveContext.cacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable")) checkAnswer( sql("SELECT * FROM jsonTable"), @@ -220,7 +223,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2, b FROM jt").collect()) // Verify uncaching - uncacheTable("jsonTable") + caseInsensitiveContext.uncacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable"), 0) } @@ -251,6 +254,6 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { "It is not allowed to insert into a table that is not an InsertableRelation." ) - dropTempTable("oneToTen") + caseInsensitiveContext.dropTempTable("oneToTen") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index c2bc52e2120c..257526feab94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -52,10 +52,9 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } class PrunedScanSuite extends DataSourceTest { - import caseInsensitiveContext._ before { - sql( + caseInsensitiveContext.sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -115,7 +114,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = sql(sqlString).queryExecution + val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 8331a14c9295..296b0d6f74a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.sources -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ResolvedDataSourceSuite extends FunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite { test("builtin sources") { assert(ResolvedDataSource.lookupDataSource("jdbc") === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 6567d1acd764..b032515a9d28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -27,7 +27,9 @@ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql + + private lazy val sparkContext = caseInsensitiveContext.sparkContext var originalDefaultSource: String = null @@ -36,62 +38,72 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { var df: DataFrame = null override def beforeAll(): Unit = { - originalDefaultSource = conf.defaultDataSourceName + originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = jsonRDD(rdd) + df = caseInsensitiveContext.read.json(rdd) df.registerTempTable("jsonTable") } override def afterAll(): Unit = { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } after { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } def checkLoad(): Unit = { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(load(path.toString), df.collect()) + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + checkAnswer(caseInsensitiveContext.read.load(path.toString), df.collect()) // Test if we can pick up the data source name passed in load. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect()) - checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect()) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) + checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( - load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)), + caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), sql("SELECT b FROM jsonTable").collect()) } test("save with path and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - df.save(path.toString) + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + df.write.save(path.toString) + checkLoad() + } + + test("save with string mode and path, and load") { + caseInsensitiveContext.conf.setConf( + SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + path.createNewFile() + df.write.mode("overwrite").save(path.toString) checkLoad() } test("save with path and datasource, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.save(path.toString, "org.apache.spark.sql.json") + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.write.json(path.toString) checkLoad() } test("save with data source and options, and load") { - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString)) + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.write.mode(SaveMode.ErrorIfExists).json(path.toString) checkLoad() } test("save and save again") { - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) var message = intercept[RuntimeException] { - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) }.getMessage assert( @@ -100,14 +112,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { if (path.exists()) Utils.deleteRecursively(path) - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) checkLoad() - df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString)) + df.write.mode(SaveMode.Overwrite).json(path.toString) checkLoad() message = intercept[RuntimeException] { - df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString)) + df.write.mode(SaveMode.Append).json(path.toString) }.getMessage assert( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 77af04a49174..2c916f3322b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.sources -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class DefaultSource extends SimpleScanSource @@ -47,6 +50,10 @@ class AllDataTypesScanSource extends SchemaRelationProvider { sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = { + // Check that weird parameters are passed correctly. + parameters("option_with_underscores") + parameters("option.with.dots") + AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext) } } @@ -60,10 +67,12 @@ case class AllDataTypesScan( override def schema: StructType = userSpecifiedSchema + override def needConversion: Boolean = false + override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - Row( - s"str_$i", + InternalRow( + UTF8String.fromString(s"str_$i"), s"str_$i".getBytes(), i % 2 == 0, i.toByte, @@ -72,25 +81,26 @@ case class AllDataTypesScan( i.toLong, i.toFloat, i.toDouble, - new java.math.BigDecimal(i), - new java.math.BigDecimal(i), - new Date(1970, 1, 1), - new Timestamp(20000 + i), - s"varchar_$i", + Decimal(new java.math.BigDecimal(i)), + Decimal(new java.math.BigDecimal(i)), + DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)), + DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)), + UTF8String.fromString(s"varchar_$i"), Seq(i, i + 1), - Seq(Map(s"str_$i" -> Row(i.toLong))), - Map(i -> i.toString), - Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), - Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) + Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), + Map(i -> UTF8String.fromString(i.toString)), + Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), + InternalRow(i, UTF8String.fromString(i.toString)), + InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), + InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) } } } class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext._ + import caseInsensitiveContext.sql - var tableWithSchemaExpected = (1 to 10).map { i => + private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( s"str_$i", s"str_$i", @@ -121,7 +131,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) @@ -152,7 +164,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) } @@ -215,7 +229,7 @@ class TableScanSuite extends DataSourceTest { Nil ) - assert(expectedSchema == table("tableWithSchema").schema) + assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema) checkAnswer( sql( @@ -270,7 +284,7 @@ class TableScanSuite extends DataSourceTest { test("Caching") { // Cached Query Execution - cacheTable("oneToTen") + caseInsensitiveContext.cacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen")) checkAnswer( sql("SELECT * FROM oneToTen"), @@ -297,7 +311,7 @@ class TableScanSuite extends DataSourceTest { (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching - uncacheTable("oneToTen") + caseInsensitiveContext.uncacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen"), 0) } @@ -354,7 +368,9 @@ class TableScanSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | from '1', - | to '10' + | to '10', + | option_with_underscores 'someval', + | option.with.dots 'someval' |) """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala new file mode 100644 index 000000000000..fa01823e9417 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File + +import scala.util.Try + +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +trait SQLTestUtils { + def sqlContext: SQLContext + + protected def configuration = sqlContext.sparkContext.hadoopConfiguration + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops temporary table `tableName` after calling `f`. + */ + protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(sqlContext.dropTempTable) + } + + /** + * Drops table `tableName` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + sqlContext.sql(s"DROP TABLE IF EXISTS $name") + } + } + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 437f697d25bf..73e6ccdb1eaf 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ spark-hive_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3458b04bfba0..700d994bb6a8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.SQLConf -import org.apache.spark.{SparkContext, SparkConf, Logging} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener} import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkContext} -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -51,6 +52,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) + sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) @@ -151,9 +153,9 @@ object HiveThriftServer2 extends Logging { val sessionList = new mutable.LinkedHashMap[String, SessionInfo] val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT, "200").toInt + conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT, "200").toInt + conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) var totalRunning = 0 override def onJobStart(jobStart: SparkListenerJobStart): Unit = { diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala similarity index 51% rename from sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index b9d4f1c58c98..e8758887ff3a 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -17,78 +17,55 @@ package org.apache.spark.sql.hive.thriftserver +import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager +import java.util.concurrent.RejectedExecutionException +import java.util.{Map => JMap, UUID} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.util.control.NonFatal +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli._ +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} +import org.apache.hive.service.cli.session.HiveSession -import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.Logging import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.13.1" - - def setServerUserName( - sparkServiceUGI: UserGroupInformation, - sparkCliService:SparkSQLCLIService) = { - setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JList[_]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) - hiveResponse = null - true - } - } -} private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], - runInBackground: Boolean = true)( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution - extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { + runInBackground: Boolean = true) + (hiveContext: HiveContext, sessionToActivePool: SMap[SessionHandle, String]) + extends ExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground) + with Logging { private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ + private var statementId: String = _ def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") + hiveContext.sparkContext.clearJobGroup() + logDebug(s"CLOSING $statementId") + cleanup(OperationState.CLOSED) } - def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { dataTypes(ordinal) match { case StringType => to += from.getString(ordinal) @@ -149,10 +126,10 @@ private[hive] class SparkExecuteStatementOperation( } def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { + if (result == null || result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") val schema = result.queryExecution.analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } @@ -160,9 +137,73 @@ private[hive] class SparkExecuteStatementOperation( } } - def run(): Unit = { - val statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement'") + override def run(): Unit = { + setState(OperationState.PENDING) + setHasResultSet(true) // avoid no resultset for async run + + if (!runInBackground) { + runInternal() + } else { + val parentSessionState = SessionState.get() + val hiveConf = getConfigForOperation() + val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + val sessionHive = getCurrentHive() + val currentSqlSession = hiveContext.currentSession + + // Runnable impl to call runInternal asynchronously, + // from a different thread + val backgroundOperation = new Runnable() { + + override def run(): Unit = { + val doAsAction = new PrivilegedExceptionAction[Object]() { + override def run(): Object = { + + // User information is part of the metastore client member in Hive + hiveContext.setSession(currentSqlSession) + Hive.set(sessionHive) + SessionState.setCurrentSessionState(parentSessionState) + try { + runInternal() + } catch { + case e: HiveSQLException => + setOperationException(e) + log.error("Error running hive query: ", e) + } + return null + } + } + + try { + ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) + } catch { + case e: Exception => + setOperationException(new HiveSQLException(e)) + logError("Error running hive query as user : " + + sparkServiceUGI.getShortUserName(), e) + } + } + } + try { + // This submit blocks if no background threads are available to run this operation + val backgroundHandle = + getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) + setBackgroundHandle(backgroundHandle) + } catch { + case rejected: RejectedExecutionException => + setState(OperationState.ERROR) + throw new HiveSQLException("The background threadpool cannot accept" + + " new task for execution, please retry the operation", rejected) + case NonFatal(e) => + logError(s"Error executing query in background", e) + setState(OperationState.ERROR) + throw e + } + } + } + + private def runInternal(): Unit = { + statementId = UUID.randomUUID().toString + logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) HiveThriftServer2.listener.onStatementStart( statementId, @@ -178,7 +219,7 @@ private[hive] class SparkExecuteStatementOperation( result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL.key, Some(value)))) => sessionToActivePool(parentSession.getSessionHandle) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => @@ -194,63 +235,82 @@ private[hive] class SparkExecuteStatementOperation( } } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) } catch { + case e: HiveSQLException => + if (getStatus().getState() == OperationState.CANCELED) { + return + } else { + setState(OperationState.ERROR); + throw e + } // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => + val currentState = getStatus().getState() + logError(s"Error executing query, currentState $currentState, ", e) setState(OperationState.ERROR) HiveThriftServer2.listener.onStatementError( statementId, e.getMessage, e.getStackTraceString) - logError("Error executing query:", e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) HiveThriftServer2.listener.onStatementFinish(statementId) } -} - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) + override def cancel(): Unit = { + logInfo(s"Cancel '$statement' with $statementId") + if (statementId != null) { + hiveContext.sparkContext.cancelJobGroup(statementId) + } + cleanup(OperationState.CANCELED) } - override def openSession( - protocol: TProtocolVersion, - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) - val session = super.getSession(sessionHandle) - HiveThriftServer2.listener.onSessionCreated( - session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - sessionHandle + private def cleanup(state: OperationState) { + setState(state) + if (runInBackground) { + val backgroundHandle = getBackgroundHandle() + if (backgroundHandle != null) { + backgroundHandle.cancel(true) + } + } } - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle + /** + * If there are query specific settings to overlay, then create a copy of config + * There are two cases we need to clone the session config that's being passed to hive driver + * 1. Async query - + * If the client changes a config setting, that shouldn't reflect in the execution + * already underway + * 2. confOverlay - + * The query specific settings should only be applied to the query config and not session + * @return new configuration + * @throws HiveSQLException + */ + private def getConfigForOperation(): HiveConf = { + var sqlOperationConf = getParentSession().getHiveConf() + if (!getConfOverlay().isEmpty() || runInBackground) { + // clone the partent session config for this query + sqlOperationConf = new HiveConf(sqlOperationConf) + + // apply overlay query specific settings, if any + getConfOverlay().foreach { case (k, v) => + try { + sqlOperationConf.verifyAndSet(k, v) + } catch { + case e: IllegalArgumentException => + throw new HiveSQLException("Error applying statement specific settings", e) + } + } + } + return sqlOperationConf + } - hiveContext.detachSession() + private def getCurrentHive(): Hive = { + try { + return Hive.get() + } catch { + case e: HiveException => + throw new HiveSQLException("Failed to get current Hive object", e); + } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index deb1008c468b..039cfa40d26b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,18 +32,18 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging -import org.apache.spark.sql.hive.{HiveContext, HiveShim} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') - private var transport:TSocket = _ + private var transport: TSocket = _ installSignalHandler() @@ -267,7 +267,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) + val proc: CommandProcessor = CommandProcessorFactory.get(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || @@ -276,13 +276,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { driver.init() val out = sessionState.out - val start:Long = System.currentTimeMillis() + val start: Long = System.currentTimeMillis() if (sessionState.getIsVerbose) { out.println(cmd) } val rc = driver.run(cmd) val end = System.currentTimeMillis() - val timeTaken:Double = (end - start) / 1000.0 + val timeTaken: Double = (end - start) / 1000.0 ret = rc.getResponseCode if (ret != 0) { @@ -310,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { res.clear() } } catch { - case e:IOException => + case e: IOException => console.printError( s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} |${org.apache.hadoop.util.StringUtils.stringifyException(e)} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 499e077d7294..41f647d5f8c5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,8 +21,6 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException -import scala.collection.JavaConversions._ - import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader @@ -34,7 +32,8 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) extends CLIService @@ -52,7 +51,7 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) try { HiveAuthFactory.loginFromKeytab(hiveConf) sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) - HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala similarity index 86% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 48ac9062af96..77272aecf283 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ +import java.util.{ArrayList => JArrayList, List => JList} import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +27,12 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -private[hive] abstract class AbstractSparkSQLDriver( - val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { +import scala.collection.JavaConversions._ + +private[hive] class SparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) + extends Driver + with Logging { private[hive] var tableSchema: Schema = _ private[hive] var hiveResponse: Seq[String] = _ @@ -71,6 +75,16 @@ private[hive] abstract class AbstractSparkSQLDriver( 0 } + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } + override def getSchema: Schema = tableSchema override def destroy() { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 7c0c505e2d61..1d41c4613182 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import scala.collection.JavaConversions._ import org.apache.spark.scheduler.StatsReportListener -import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils @@ -38,9 +38,14 @@ private[hive] object SparkSQLEnv extends Logging { val sparkConf = new SparkConf(loadDefaults = true) val maybeSerializer = sparkConf.getOption("spark.serializer") val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") + // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of + // the default appName [SparkSQLCLIDriver] in cli or beeline. + val maybeAppName = sparkConf + .getOption("spark.app.name") + .filterNot(_ == classOf[SparkSQLCLIDriver].getName) sparkConf - .setAppName(s"SparkSQL::${Utils.localHostName()}") + .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) .set( "spark.serializer", maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) @@ -56,7 +61,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - hiveContext.setConf("spark.sql.hive.version", HiveShim.version) + hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala new file mode 100644 index 000000000000..2d5ee6800228 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.concurrent.Executors + +import org.apache.commons.logging.Log +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.service.cli.SessionHandle +import org.apache.hive.service.cli.session.SessionManager +import org.apache.hive.service.cli.thrift.TProtocolVersion + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager + + +private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) + extends SessionManager + with ReflectedCompositeService { + + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + + override def init(hiveConf: HiveConf) { + setSuperField(this, "hiveConf", hiveConf) + + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) + setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) + getAncestorField[Log](this, 3, "LOG").info( + s"HiveServer2: Async execution pool size $backgroundPoolSize") + + setSuperField(this, "operationManager", sparkSqlOperationManager) + addService(sparkSqlOperationManager) + + initCompositeService(hiveConf) + } + + override def openSession( + protocol: TProtocolVersion, + username: String, + passwd: String, + sessionConf: java.util.Map[String, String], + withImpersonation: Boolean, + delegationToken: String): SessionHandle = { + hiveContext.openSession() + val sessionHandle = super.openSession( + protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val session = super.getSession(sessionHandle) + HiveThriftServer2.listener.onSessionCreated( + session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + sessionHandle + } + + override def closeSession(sessionHandle: SessionHandle) { + HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) + super.closeSession(sessionHandle) + sparkSqlOperationManager.sessionToActivePool -= sessionHandle + + hiveContext.detachSession() + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9c0bf02391e0..c8031ed0f343 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -44,9 +44,12 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( - hiveContext, sessionToActivePool) + val runInBackground = async && hiveContext.hiveThriftServerAsync + val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, + runInBackground)(hiveContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created Operation for $statement with session=$parentSession, " + + s"runInBackground=$runInBackground") operation } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 6a2be4a58e5c..10c83d8b27a2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,7 +47,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" ++ generateSessionStatsTable() ++ generateSQLStatsTable() - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -77,7 +77,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} @@ -85,7 +85,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {info.groupId} {formatDate(info.startTimestamp)} - {if(info.finishTimestamp > 0) formatDate(info.finishTimestamp)} + {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} {formatDurationOption(Some(info.totalTime))} {info.statement} {info.state} @@ -143,14 +143,14 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/ThriftServer/session?id=%s" + val sessionLink = "%s/sql/session?id=%s" .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) {session.userName} {session.ip} {session.sessionId} {formatDate(session.startTimestamp)} - {if(session.finishTimestamp > 0) formatDate(session.finishTimestamp)} + {if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)} {formatDurationOption(Some(session.totalTime))} {session.totalExecution.toString} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 33ba038ecce7..3b01afa603ce 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -55,7 +55,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Total run {sessionStat._2.totalExecution} SQL ++ generateSQLStatsTable(sessionStat._2.sessionId) - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ @@ -87,7 +87,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 343031f10c75..94fd8a6bb60b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,7 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. */ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "ThriftServer") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + + override val name = "SQL" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index b070fa8eaa46..13b0c5951ddd 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,12 +25,16 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils -class CliSuite extends FunSuite with BeforeAndAfter with Logging { +/** + * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary + * Hive metastore and warehouse. + */ +class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() @@ -58,13 +62,13 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --driver-class-path ${sys.props("java.class.path")} """.stripMargin.split("\\s+").toSeq ++ extraArgs } var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. + val queryStream = new ByteArrayInputStream(queries.map(_ + "\n").mkString.getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object @@ -124,12 +128,12 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "Time taken: " + -> "OK" ) } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } test("Single command with --database") { @@ -151,4 +155,33 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { -> "hive_test" ) } + + test("Commands using SerDe provided in --jars") { + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + runCliWithin(3.minute, Seq("--jars", s"$jarFile"))( + """CREATE TABLE t1(key string, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; + """.stripMargin + -> "OK", + "CREATE TABLE sourceTable (key INT, val STRING);" + -> "OK", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" + -> "OK", + "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" + -> "Time taken:", + "SELECT count(key) FROM t1;" + -> "5", + "DROP TABLE t1;" + -> "OK", + "DROP TABLE sourceTable;" + -> "OK" + ) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1fadea97fd07..301aa5a6411e 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,14 +19,18 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.sql.{Date, DriverManager, Statement} +import java.nio.charset.StandardCharsets +import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise} +import scala.concurrent.{Await, Promise, future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper @@ -35,10 +39,10 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging -import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils object TestData { @@ -54,7 +58,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { - // Transport creation logics below mimics HiveConnection.createBinaryTransport + // Transport creation logic below mimics HiveConnection.createBinaryTransport val rawTransport = new TSocket("localhost", serverPort) val user = System.getProperty("user.name") val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) @@ -109,7 +113,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) } } @@ -233,7 +238,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // first session, we get the default value of the session status { statement => - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() defaultV1 = rs1.getString(1) assert(defaultV1 != "200") @@ -251,19 +256,21 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { { statement => val queries = Seq( - s"SET ${SQLConf.SHUFFLE_PARTITIONS}=291", + s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291", "SET hive.cli.print.header=true" ) queries.map(statement.execute) - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() - assert("spark.sql.shuffle.partitions=291" === rs1.getString(1)) + assert("spark.sql.shuffle.partitions" === rs1.getString(1)) + assert("291" === rs1.getString(2)) rs1.close() val rs2 = statement.executeQuery("SET hive.cli.print.header") rs2.next() - assert("hive.cli.print.header=true" === rs2.getString(1)) + assert("hive.cli.print.header" === rs2.getString(1)) + assert("true" === rs2.getString(2)) rs2.close() }, @@ -271,7 +278,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // default value { statement => - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() assert(defaultV1 === rs1.getString(1)) rs1.close() @@ -335,6 +342,42 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } ) } + + test("test jdbc cancel") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + val largeJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(10)("join test_map").mkString(" ") + val f = future { Thread.sleep(100); statement.cancel(); } + val e = intercept[SQLException] { + statement.executeQuery(largeJoin) + } + assert(e.getMessage contains "cancelled") + Await.result(f, 3.minute) + + // cancel is a noop + statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") + val sf = future { Thread.sleep(100); statement.cancel(); } + val smallJoin = "SELECT COUNT(*) FROM test_map " + + List.fill(4)("join test_map").mkString(" ") + val rs1 = statement.executeQuery(smallJoin) + Await.result(sf, 3.minute) + rs1.next() + assert(rs1.getInt(1) === math.pow(5, 5)) + rs1.close() + + val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") + rs2.next() + assert(rs2.getInt(1) === 5) + rs2.close() + } + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { @@ -363,7 +406,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) } } } @@ -391,10 +435,10 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { val statements = connections.map(_.createStatement()) try { - statements.zip(fs).map { case (s, f) => f(s) } + statements.zip(fs).foreach { case (s, f) => f(s) } } finally { - statements.map(_.close()) - connections.map(_.close()) + statements.foreach(_.close()) + connections.foreach(_.close()) } } @@ -403,7 +447,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } -abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { +abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging { def mode: ServerMode.Value private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") @@ -433,15 +477,33 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT } + val driverClassPath = { + // Writes a temporary log4j.properties and prepend it to driver classpath, so that it + // overrides all other potential log4j configurations contained in other dependency jar files. + val tempLog4jConf = Utils.createTempDir().getCanonicalPath + + Files.write( + """log4j.rootCategory=INFO, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin, + new File(s"$tempLog4jConf/log4j.properties"), + UTF_8) + + tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + } + s"""$startScript | --master local - | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf $portConf=$port - | --driver-class-path ${sys.props("java.class.path")} + | --driver-class-path $driverClassPath + | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 47541015a361..806240e6de45 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.hive.thriftserver - - import scala.util.Random +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver -import org.scalatest.{Matchers, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfterAll, Matchers} import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql.hive.HiveContext - +import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite extends HiveThriftJdbcTest @@ -43,7 +41,9 @@ class UISeleniumSuite override def mode: ServerMode.Value = ServerMode.binary override def beforeAll(): Unit = { - webDriver = new HtmlUnitDriver + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } super.beforeAll() } @@ -75,9 +75,9 @@ class UISeleniumSuite """.stripMargin.split("\\s+").toSeq } - test("thrift server ui test") { - withJdbcStatement(statement =>{ - val baseURL = s"http://localhost:${uiPort}" + ignore("thrift server ui test") { + withJdbcStatement { statement => + val baseURL = s"http://localhost:$uiPort" val queries = Seq( "CREATE TABLE test_map(key INT, value STRING)", @@ -86,20 +86,20 @@ class UISeleniumSuite queries.foreach(statement.execute) eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (baseURL) - find(cssSelector("""ul li a[href*="ThriftServer"]""")) should not be(None) + go to baseURL + find(cssSelector("""ul li a[href*="sql"]""")) should not be None } eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (baseURL + "/ThriftServer") - find(id("sessionstat")) should not be(None) - find(id("sqlstat")) should not be(None) + go to (baseURL + "/sql") + find(id("sessionstat")) should not be None + find(id("sqlstat")) should not be None // check whether statements exists queries.foreach { line => findAll(cssSelector("""ul table tbody tr td""")).map(_.text).toList should contain (line) } } - }) + } } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala deleted file mode 100644 index b3a79ba1c7d6..000000000000 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} - -import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} -import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.types._ - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.12.0" - - def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { - val serverUserName = ShimLoader.getHadoopShims.getShortUserName(sparkServiceUGI) - setSuperField(sparkCliService, "serverUserName", serverUserName) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JArrayList[String]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.addAll(hiveResponse) - hiveResponse = null - true - } - } -} - -private[hive] class SparkExecuteStatementOperation( - parentSession: HiveSession, - statement: String, - confOverlay: JMap[String, String])( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - - private var result: DataFrame = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - if (!iter.hasNext) { - new RowSet() - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) - - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = new Row() - var curCol = 0 - - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - addNullColumnValue(sparkRow, row, curCol) - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - rowSet += row - curRow += 1 - } - new RowSet(rowSet, 0) - } - } - - def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(from(ordinal).asInstanceOf[String]) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType() => - val hiveDecimal = from.getDecimal(ordinal) - to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case DateType => - to.addColumnValue(ColumnValue.dateValue(from(ordinal).asInstanceOf[Date])) - case TimestampType => - to.addColumnValue( - ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) - to.addColumnValue(ColumnValue.stringValue(hiveString)) - } - } - - def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(null) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(null)) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(null)) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(null)) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType() => - to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) - case LongType => - to.addColumnValue(ColumnValue.longValue(null)) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(null)) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(null)) - case DateType => - to.addColumnValue(ColumnValue.dateValue(null)) - case TimestampType => - to.addColumnValue(ColumnValue.timestampValue(null)) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - to.addColumnValue(ColumnValue.stringValue(null: String)) - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - val statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - HiveThriftServer2.listener.onStatementStart( - statementId, parentSession.getSessionHandle.getSessionId.toString, statement, statementId) - hiveContext.sparkContext.setJobGroup(statementId, statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - try { - result = hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => - sessionToActivePool(parentSession.getSessionHandle) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) - iter = { - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - result.rdd.toLocalIterator - } else { - result.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - setState(OperationState.ERROR) - HiveThriftServer2.listener.onStatementError( - statementId, e.getMessage, e.getStackTraceString) - logError("Error executing query:",e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - HiveThriftServer2.listener.onStatementFinish(statementId) - } -} - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def openSession( - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - username, passwd, sessionConf, withImpersonation, delegationToken) - HiveThriftServer2.listener.onSessionCreated("UNKNOWN", sessionHandle.getSessionId.toString) - sessionHandle - } - - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b6245a57074c..415a81644c58 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -23,7 +23,6 @@ import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.HiveShim import org.apache.spark.sql.hive.test.TestHive /** @@ -48,17 +47,17 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Add Locale setting Locale.setDefault(Locale.US) // Set a relatively small column batch size for testing purposes - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5") + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ @@ -250,8 +249,15 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The isolated classloader seemed to make some of our test reset mechanisms less robust. "combine1", // This test changes compression settings in a way that breaks all subsequent tests. - "load_dyn_part14.*" // These work alone but fail when run with other tests... - ) ++ HiveShim.compatibilityBlackList + "load_dyn_part14.*", // These work alone but fail when run with other tests... + + // the answer is sensitive for jdk version + "udf_java_method", + + // Spark SQL use Long for TimestampType, lose the precision under 100ns + "timestamp_1", + "timestamp_2" + ) /** * The set of tests that are believed to be working in catalyst. Tests not on whiteList or @@ -793,8 +799,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_publisher_error_1", "subq2", "tablename_with_select", - "timestamp_1", - "timestamp_2", "timestamp_3", "timestamp_comparison", "timestamp_lazy", @@ -815,19 +819,19 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf2", "udf5", "udf6", - "udf7", + // "udf7", turn this on after we figure out null vs nan vs infinity "udf8", "udf9", "udf_10_trims", "udf_E", "udf_PI", "udf_abs", - "udf_acos", + // "udf_acos", turn this on after we figure out null vs nan vs infinity "udf_add", "udf_array", "udf_array_contains", "udf_ascii", - "udf_asin", + // "udf_asin", turn this on after we figure out null vs nan vs infinity "udf_atan", "udf_avg", "udf_bigint", @@ -877,7 +881,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_int", "udf_isnotnull", "udf_isnull", - "udf_java_method", "udf_lcase", "udf_length", "udf_lessthan", @@ -916,7 +919,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_repeat", "udf_rlike", "udf_round", - "udf_round_3", + // "udf_round_3", TODO: FIX THIS failed due to cast exception "udf_rpad", "udf_rtrim", "udf_second", @@ -946,6 +949,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_trim", "udf_ucase", "udf_unix_timestamp", + "udf_unhex", "udf_upper", "udf_var_pop", "udf_var_samp", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 934452fe579a..31a49a368333 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -526,8 +526,14 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with | rows between 2 preceding and 2 following); """.stripMargin, reset = false) + // collect_set() output array in an arbitrary order, hence causes different result + // when running this test suite under Java 7 and 8. + // We change the original sql query a little bit for making the test suite passed + // under different JDK createQueryTest("windowing.q -- 20. testSTATs", """ + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( |select p_mfgr,p_name, p_size, |stddev(p_retailprice) over w1 as sdev, |stddev_pop(p_retailprice) over w1 as sdev_pop, @@ -538,6 +544,8 @@ abstract class HiveWindowFunctionQueryBaseSuite extends HiveComparisonTest with |from part |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) createQueryTest("windowing.q -- 21. testDISTs", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index 65d070bd3cbd..f458567e5d7e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.hive.test.TestHive class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) } override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) super.afterAll() } diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index e322340094e6..a17546d70624 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../../pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} @@ -136,16 +143,6 @@ - - hive-0.12.0 - - - com.twitter - parquet-hive-bundle - 1.5.0 - - - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 3f20c6142e59..7f8449cdc282 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -29,10 +29,10 @@ import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") + protected val JAR = Keyword("JAR") protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9d98c36e947a..b91242af2d15 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -17,36 +17,35 @@ package org.apache.spark.sql.hive -import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.io.File +import java.net.{URL, URLClassLoader} import java.sql.Timestamp -import java.util.{ArrayList => JArrayList} - -import org.apache.hadoop.hive.ql.parse.VariableSubstitution -import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.SQLConf.SQLConfEntry +import org.apache.spark.sql.SQLConf.SQLConfEntry._ +import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} -import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy} -import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -71,13 +70,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { import HiveContext._ + println("create HiveContext") + /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive * SerDe. */ - protected[sql] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + protected[sql] def convertMetastoreParquet: Boolean = getConf(CONVERT_METASTORE_PARQUET) /** * When true, also tries to merge possibly different but compatible Parquet schemas in different @@ -86,7 +86,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. */ protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) /** * When true, a table created by a Hive CTAS statement (no USING clause) will be @@ -100,8 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format * and no SerDe is specified (no ROW FORMAT SERDE clause). */ - protected[sql] def convertCTAS: Boolean = - getConf("spark.sql.hive.convertCTAS", "false").toBoolean + protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) /** * The version of the hive client that will be used to communicate with the metastore. Note that @@ -119,8 +118,30 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * option is only valid when using the execution version of Hive. * - maven - download the correct version of hive on demand from maven. */ - protected[hive] def hiveMetastoreJars: String = - getConf(HIVE_METASTORE_JARS, "builtin") + protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS) + + /** + * A comma separated list of class prefixes that should be loaded using the classloader that + * is shared between Spark SQL and a specific version of Hive. An example of classes that should + * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + * to be shared are those that interact with classes that are already shared. For example, + * custom appenders that are used by log4j. + */ + protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = + getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") + + /** + * A comma separated list of class prefixes that should explicitly be reloaded for each version + * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + * prefix that typically would be shared (i.e. org.apache.spark.*) + */ + protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = + getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") + + /* + * hive thrift server use background spark sql thread pool to execute sql queries + */ + protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -130,14 +151,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Hive 13 as this is the version of Hive that is packaged with Spark SQL. This copy of the * client is used for execution related tasks like registering temporary functions or ensuring * that the ThreadLocal SessionState is correctly populated. This copy of Hive is *not* used - * for storing peristent metadata, and only point to a dummy metastore in a temporary directory. + * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. */ @transient protected[hive] lazy val executionHive: ClientWrapper = { - logInfo(s"Initilizing execution hive, version $hiveExecutionVersion") + logInfo(s"Initializing execution hive, version $hiveExecutionVersion") new ClientWrapper( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), - config = newTemporaryConfiguration()) + config = newTemporaryConfiguration(), + initClassLoader = Utils.getContextOrSparkClassLoader) } SessionState.setCurrentSessionState(executionHive.state) @@ -164,13 +186,22 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") } - val jars = getClass.getClassLoader match { - case urlClassLoader: java.net.URLClassLoader => urlClassLoader.getURLs - case other => - throw new IllegalArgumentException( - "Unable to locate hive jars to connect to metastore " + - s"using classloader ${other.getClass.getName}. " + - "Please set spark.sql.hive.metastore.jars") + + // We recursively find all jars in the class loader chain, + // starting from the given classLoader. + def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { + case null => Array.empty[URL] + case urlClassLoader: URLClassLoader => + urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) + case other => allJars(other.getParent) + } + + val classLoader = Utils.getContextOrSparkClassLoader + val jars = allJars(classLoader) + if (jars.length == 0) { + throw new IllegalArgumentException( + "Unable to locate hive jars to connect to metastore. " + + "Please set spark.sql.hive.metastore.jars.") } logInfo( @@ -179,12 +210,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { version = metaVersion, execJars = jars.toSeq, config = allConfig, - isolationOn = true) + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else if (hiveMetastoreJars == "maven") { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig ) + IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) } else { // Convert to files and expand any directories. val jars = @@ -210,7 +243,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { version = metaVersion, execJars = jars.toSeq, config = allConfig, - isolationOn = true) + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } isolatedLoader.client } @@ -231,13 +266,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - // TODO: Database support... - catalog.refreshTable("default", tableName) + catalog.refreshTable(catalog.client.currentDatabase, tableName) } protected[hive] def invalidateTable(tableName: String): Unit = { - // TODO: Database support... - catalog.invalidateTable("default", tableName) + catalog.invalidateTable(catalog.client.currentDatabase, tableName) } /** @@ -293,7 +326,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val tableParameters = relation.hiveQlTable.getParameters val oldTotalSize = - Option(tableParameters.get(HiveShim.getStatsSetupConstTotalSize)) + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) .map(_.toLong) .getOrElse(0L) val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) @@ -304,7 +337,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.client.alterTable( relation.table.copy( properties = relation.table.properties + - (HiveShim.getStatsSetupConstTotalSize -> newTotalSize.toString))) + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) } case otherRelation => throw new UnsupportedOperationException( @@ -316,22 +349,29 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def setConf(key: String, value: String): Unit = { super.setConf(key, value) - hiveconf.set(key, value) executionHive.runSqlHive(s"SET $key=$value") metadataHive.runSqlHive(s"SET $key=$value") + // If users put any Spark SQL setting in the spark conf (e.g. spark-defaults.conf), + // this setConf will be called in the constructor of the SQLContext. + // Also, calling hiveconf will create a default session containing a HiveConf, which + // will interfer with the creation of executionHive (which is a lazy val). So, + // we put hiveconf.set at the end of this method. + hiveconf.set(key, value) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ + private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + setConf(entry.key, entry.stringConverter(value)) + } + + /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient override protected[sql] lazy val catalog = new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry { - def caseSensitive: Boolean = false - } + override protected[sql] lazy val functionRegistry: FunctionRegistry = + new OverrideFunctionRegistry(new HiveFunctionRegistry(FunctionRegistry.builtin)) /* An analyzer that uses the Hive metastore. */ @transient @@ -341,10 +381,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } override protected[sql] def createSession(): SQLSession = { @@ -357,8 +401,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[hive] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = - getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } /** @@ -399,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveCommandStrategy(self), HiveDDLStrategy, DDLStrategy, - TakeOrdered, + TakeOrderedAndProject, ParquetOperations, InMemoryScans, ParquetConversion, // Must be before HiveTableScans @@ -474,14 +517,68 @@ private[hive] object HiveContext { val hiveExecutionVersion: String = "0.13.1" val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" - val HIVE_METASTORE_JARS: String = "spark.sql.hive.metastore.jars" + val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", + defaultValue = Some("builtin"), + doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" + + " property can be one of three options: " + + "1. \"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " + + "-Phive is enabled. When this option is chosen, " + + "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " + + "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." + + "3. A classpath in the standard format for both Hive and Hadoop.") + + val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", + defaultValue = Some(true), + doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + + "the built in support.") + + val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( + "spark.sql.hive.convertMetastoreParquet.mergeSchema", + defaultValue = Some(false), + doc = "TODO") + + val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", + defaultValue = Some(false), + doc = "TODO") + + val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", + defaultValue = Some(jdbcPrefixes), + doc = "A comma separated list of class prefixes that should be loaded using the classloader " + + "that is shared between Spark SQL and a specific version of Hive. An example of classes " + + "that should be shared is JDBC drivers that are needed to talk to the metastore. Other " + + "classes that need to be shared are those that interact with classes that are already " + + "shared. For example, custom appenders that are used by log4j.") + + private def jdbcPrefixes = Seq( + "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc") + + val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes", + defaultValue = Some(Seq()), + doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " + + "version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " + + "declared in a prefix that typically would be shared (i.e. org.apache.spark.*).") + + val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", + defaultValue = Some(true), + doc = "TODO") /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ def newTemporaryConfiguration(): Map[String, String] = { val tempDir = Utils.createTempDir() val localMetastore = new File(tempDir, "metastore").getAbsolutePath - Map( - "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$localMetastore;create=true") + val propMap: HashMap[String, String] = HashMap() + // We have to mask all properties in hive-site.xml that relates to metastore data source + // as we used a local metastore here. + HiveConf.ConfVars.values().foreach { confvar => + if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { + propMap.put(confvar.varname, confvar.defaultVal) + } + } + propMap.put("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$localMetastore;create=true") + propMap.put("datanucleus.rdbms.datastoreAdapterClassName", + "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + propMap.toMap } protected val primitiveTypes = @@ -495,7 +592,7 @@ private[hive] object HiveContext { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -506,7 +603,7 @@ private[hive] object HiveContext { case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveShim.createDecimal(decimal).toString + HiveDecimal.create(decimal).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 7c7666f6e4b7..a925e18ee145 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -18,15 +18,17 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} - +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.types +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, types} +import org.apache.spark.unsafe.types.UTF8String /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -122,7 +124,7 @@ import scala.collection.JavaConversions._ * even a normal java object (POJO) * UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet) * - * 3) ConstantObjectInspector: + * 3) ConstantObjectInspector: * Constant object inspector can be either primitive type or Complex type, and it bundles a * constant value as its property, usually the value is created when the constant object inspector * constructed. @@ -133,7 +135,7 @@ import scala.collection.JavaConversions._ } }}} * Hive provides 3 built-in constant object inspectors: - * Primitive Object Inspectors: + * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector * WritableConstantHiveDecimalObjectInspector @@ -147,9 +149,9 @@ import scala.collection.JavaConversions._ * WritableConstantByteObjectInspector * WritableConstantBinaryObjectInspector * WritableConstantDateObjectInspector - * Map Object Inspector: + * Map Object Inspector: * StandardConstantMapObjectInspector - * List Object Inspector: + * List Object Inspector: * StandardConstantListObjectInspector]] * Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct * Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union @@ -170,7 +172,7 @@ import scala.collection.JavaConversions._ * e.g. date_add(printf("%s-%s-%s", a,b,c), 3) * We don't need to unwrap the data for printf and wrap it again and passes in data_add */ -private[hive] trait HiveInspectors { +private[hive] trait HiveInspectors extends Logging { def javaClassToDataType(clz: Class[_]): DataType = clz match { // writable @@ -216,6 +218,14 @@ private[hive] trait HiveInspectors { // Hive seems to return this for struct types? case c: Class[_] if c == classOf[java.lang.Object] => NullType + + // java list type unsupported + case c: Class[_] if c == classOf[java.util.List[_]] => + throw new AnalysisException( + "List type in java is unsupported because " + + "JVM type erasure makes spark fail to catch a component type in List<>") + + case c => throw new AnalysisException(s"Unsupported java type $c") } /** @@ -241,18 +251,19 @@ private[hive] trait HiveInspectors { def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null case poi: WritableConstantStringObjectInspector => - UTF8String(poi.getWritableConstantValue.toString) + UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) + UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => - poi.getWritableConstantValue.getTimestamp.clone() - case poi: WritableConstantIntObjectInspector => + val t = poi.getWritableConstantValue + t.getSeconds * 10000000L + t.getNanos / 100L + case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() - case poi: WritableConstantDoubleObjectInspector => + case poi: WritableConstantDoubleObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantBooleanObjectInspector => poi.getWritableConstantValue.get() @@ -270,7 +281,7 @@ private[hive] trait HiveInspectors { System.arraycopy(writable.getBytes, 0, temp, 0, temp.length) temp case poi: WritableConstantDateObjectInspector => - DateUtils.fromJavaDate(poi.getWritableConstantValue.get()) + DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data mi.getWritableConstantValue.map { case (k, v) => @@ -286,13 +297,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => - UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - UTF8String(x.getPrimitiveWritableObject(data).toString) + UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => - UTF8String(x.getPrimitiveJavaObject(data)) + UTF8String.fromString(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -306,17 +317,17 @@ private[hive] trait HiveInspectors { // In order to keep backward-compatible, we have to copy the // bytes with old apis val bw = x.getPrimitiveWritableObject(data) - val result = new Array[Byte](bw.getLength()) + val result = new Array[Byte](bw.getLength()) System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => - DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) - case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) - // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object - // if next timestamp is null, so Timestamp object is cloned + DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) + case x: DateObjectInspector => DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) case x: TimestampObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).getTimestamp.clone() - case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() + val t = x.getPrimitiveWritableObject(data) + t.getSeconds * 10000000L + t.getNanos / 100 + case ti: TimestampObjectInspector => + DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) } case li: ListObjectInspector => @@ -333,9 +344,8 @@ private[hive] trait HiveInspectors { // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + InternalRow.fromSeq( + allRefs.map(r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } @@ -350,17 +360,20 @@ private[hive] trait HiveInspectors { new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) case _: JavaDateObjectInspector => - (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + (o: Any) => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + + case _: JavaTimestampObjectInspector => + (o: Any) => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row].toSeq).zipped.foreach { + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct @@ -394,6 +407,30 @@ private[hive] trait HiveInspectors { identity[Any] } + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ + def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit = + field.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) + } + /** * Converts native catalyst types to the types expected by Hive * @param a the value to be wrapped @@ -415,36 +452,36 @@ private[hive] trait HiveInspectors { case _ if a == null => null case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. - case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) + case _: StringObjectInspector if x.preferWritable() => getStringWritable(a) case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() - case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) + case _: IntObjectInspector if x.preferWritable() => getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] - case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) + case _: BooleanObjectInspector if x.preferWritable() => getBooleanWritable(a) case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean] - case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a) + case _: FloatObjectInspector if x.preferWritable() => getFloatWritable(a) case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float] - case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a) + case _: DoubleObjectInspector if x.preferWritable() => getDoubleWritable(a) case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double] - case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a) + case _: LongObjectInspector if x.preferWritable() => getLongWritable(a) case _: LongObjectInspector => a.asInstanceOf[java.lang.Long] - case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a) + case _: ShortObjectInspector if x.preferWritable() => getShortWritable(a) case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short] - case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a) + case _: ByteObjectInspector if x.preferWritable() => getByteWritable(a) case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte] case _: HiveDecimalObjectInspector if x.preferWritable() => - HiveShim.getDecimalWritable(a.asInstanceOf[Decimal]) + getDecimalWritable(a.asInstanceOf[Decimal]) case _: HiveDecimalObjectInspector => - HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal) - case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a) + HiveDecimal.create(a.asInstanceOf[Decimal].toJavaBigDecimal) + case _: BinaryObjectInspector if x.preferWritable() => getBinaryWritable(a) case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]] - case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a) - case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) - case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a) - case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] + case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) + case _: DateObjectInspector => DateTimeUtils.toJavaDate(a.asInstanceOf[Int]) + case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) + case _: TimestampObjectInspector => DateTimeUtils.toJavaTimestamp(a.asInstanceOf[Long]) } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Row] + val row = a.asInstanceOf[InternalRow] // 1. create the pojo (most likely) object val result = x.create() var i = 0 @@ -460,7 +497,7 @@ private[hive] trait HiveInspectors { result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs - val row = a.asInstanceOf[Row] + val row = a.asInstanceOf[InternalRow] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { @@ -487,7 +524,7 @@ private[hive] trait HiveInspectors { } def wrap( - row: Row, + row: InternalRow, inspectors: Seq[ObjectInspector], cache: Array[AnyRef]): Array[AnyRef] = { var i = 0 @@ -537,8 +574,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) } /** @@ -550,31 +587,31 @@ private[hive] trait HiveInspectors { */ def toInspector(expr: Expression): ObjectInspector = expr match { case Literal(value, StringType) => - HiveShim.getStringWritableConstantObjectInspector(value) + getStringWritableConstantObjectInspector(value) case Literal(value, IntegerType) => - HiveShim.getIntWritableConstantObjectInspector(value) + getIntWritableConstantObjectInspector(value) case Literal(value, DoubleType) => - HiveShim.getDoubleWritableConstantObjectInspector(value) + getDoubleWritableConstantObjectInspector(value) case Literal(value, BooleanType) => - HiveShim.getBooleanWritableConstantObjectInspector(value) + getBooleanWritableConstantObjectInspector(value) case Literal(value, LongType) => - HiveShim.getLongWritableConstantObjectInspector(value) + getLongWritableConstantObjectInspector(value) case Literal(value, FloatType) => - HiveShim.getFloatWritableConstantObjectInspector(value) + getFloatWritableConstantObjectInspector(value) case Literal(value, ShortType) => - HiveShim.getShortWritableConstantObjectInspector(value) + getShortWritableConstantObjectInspector(value) case Literal(value, ByteType) => - HiveShim.getByteWritableConstantObjectInspector(value) + getByteWritableConstantObjectInspector(value) case Literal(value, BinaryType) => - HiveShim.getBinaryWritableConstantObjectInspector(value) + getBinaryWritableConstantObjectInspector(value) case Literal(value, DateType) => - HiveShim.getDateWritableConstantObjectInspector(value) + getDateWritableConstantObjectInspector(value) case Literal(value, TimestampType) => - HiveShim.getTimestampWritableConstantObjectInspector(value) + getTimestampWritableConstantObjectInspector(value) case Literal(value, DecimalType()) => - HiveShim.getDecimalWritableConstantObjectInspector(value) + getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => - HiveShim.getPrimitiveNullWritableConstantObjectInspector + getPrimitiveNullWritableConstantObjectInspector case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) if (value == null) { @@ -634,8 +671,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) - case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) + case w: WritableHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -644,17 +681,143 @@ private[hive] trait HiveInspectors { case _: JavaVoidObjectInspector => NullType } + private def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + private def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, getStringWritable(value)) + + private def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.intTypeInfo, getIntWritable(value)) + + private def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) + + private def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) + + private def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.longTypeInfo, getLongWritable(value)) + + private def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) + + private def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.shortTypeInfo, getShortWritable(value)) + + private def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.byteTypeInfo, getByteWritable(value)) + + private def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) + + private def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.dateTypeInfo, getDateWritable(value)) + + private def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) + + private def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) + + private def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.voidTypeInfo, null) + + private def getStringWritable(value: Any): hadoopIo.Text = + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes) + + private def getIntWritable(value: Any): hadoopIo.IntWritable = + if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) + + private def getDoubleWritable(value: Any): hiveIo.DoubleWritable = + if (value == null) { + null + } else { + new hiveIo.DoubleWritable(value.asInstanceOf[Double]) + } + + private def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = + if (value == null) { + null + } else { + new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) + } + + private def getLongWritable(value: Any): hadoopIo.LongWritable = + if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) + + private def getFloatWritable(value: Any): hadoopIo.FloatWritable = + if (value == null) { + null + } else { + new hadoopIo.FloatWritable(value.asInstanceOf[Float]) + } + + private def getShortWritable(value: Any): hiveIo.ShortWritable = + if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) + + private def getByteWritable(value: Any): hiveIo.ByteWritable = + if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) + + private def getBinaryWritable(value: Any): hadoopIo.BytesWritable = + if (value == null) { + null + } else { + new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) + } + + private def getDateWritable(value: Any): hiveIo.DateWritable = + if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) + + private def getTimestampWritable(value: Any): hiveIo.TimestampWritable = + if (value == null) { + null + } else { + new hiveIo.TimestampWritable(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long])) + } + + private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = + if (value == null) { + null + } else { + // TODO precise, scale? + new hiveIo.HiveDecimalWritable( + HiveDecimal.create(value.asInstanceOf[Decimal].toJavaBigDecimal)) + } + implicit class typeInfoConversions(dt: DataType) { + import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory._ import org.apache.hadoop.hive.serde2.typeinfo._ - import TypeInfoFactory._ + + private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo( + HiveShim.UNLIMITED_DECIMAL_PRECISION, + HiveShim.UNLIMITED_DECIMAL_SCALE) + } def toTypeInfo: TypeInfo = dt match { case ArrayType(elemType, _) => getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name) :_*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) :_*)) + java.util.Arrays.asList(fields.map(_.name) : _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo @@ -666,7 +829,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case d: DecimalType => HiveShim.decimalTypeInfo(d) + case d: DecimalType => decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2aa80b47a97e..f35ae96ee0b5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} + import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ -import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} @@ -37,7 +39,6 @@ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} -import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -66,11 +67,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive def schemaStringFromParts: Option[String] = { table.properties.get("spark.sql.sources.schema.numParts").map { numParts => val parts = (0 until numParts.toInt).map { index => - val part = table.properties.get(s"spark.sql.sources.schema.part.${index}").orNull + val part = table.properties.get(s"spark.sql.sources.schema.part.$index").orNull if (part == null) { throw new AnalysisException( - s"Could not read schema from the metastore because it is corrupted " + - s"(missing part ${index} of the schema).") + "Could not read schema from the metastore because it is corrupted " + + s"(missing part $index of the schema, $numParts parts are expected).") } part @@ -89,6 +90,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val userSpecifiedSchema = schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) + // We only need names at here since userSpecifiedSchema we loaded from the metastore + // contains partition columns. We can always get datatypes of partitioning columns + // from userSpecifiedSchema. + val partitionColumns = table.partitionColumns.map(_.name) + // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... val options = table.serdeProperties @@ -97,7 +103,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive ResolvedDataSource( hive, userSpecifiedSchema, - Array.empty[String], + partitionColumns.toArray, table.properties("spark.sql.sources.provider"), options) @@ -111,8 +117,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive override def refreshTable(databaseName: String, tableName: String): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. - // Since we also cache ParquetRealtions converted from Hive Parquet tables and - // adding converted ParquetRealtions into the cache is not defined in the load function + // Since we also cache ParquetRelations converted from Hive Parquet tables and + // adding converted ParquetRelations into the cache is not defined in the load function // of the cache (instead, we add the cache entry in convertToParquetRelation), // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for @@ -133,12 +139,17 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = processDatabaseAndTableName("default", tableName) + val (dbName, tblName) = processDatabaseAndTableName(client.currentDatabase, tableName) val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) + + // Saves optional user specified schema. Serialized JSON schema string may be too long to be + // stored into a single metastore SerDe property. In this case, we split the JSON string and + // store each part as a separate SerDe property. if (userSpecifiedSchema.isDefined) { val threshold = conf.schemaStringLengthThreshold val schemaJsonString = userSpecifiedSchema.get.json @@ -146,8 +157,29 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val parts = schemaJsonString.grouped(threshold).toSeq tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) parts.zipWithIndex.foreach { case (part, index) => - tableProperties.put(s"spark.sql.sources.schema.part.${index}", part) + tableProperties.put(s"spark.sql.sources.schema.part.$index", part) + } + } + + val metastorePartitionColumns = userSpecifiedSchema.map { schema => + val fields = partitionColumns.map(col => schema(col)) + fields.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + }.toSeq + }.getOrElse { + if (partitionColumns.length > 0) { + // The table does not have a specified schema, which means that the schema will be inferred + // when we load the table. So, we are not expecting partition columns and we will discover + // partitions when we load the table. However, if there are specified partition columns, + // we simplily ignore them and provide a warning message.. + logWarning( + s"The schema and partitions of table $tableName will be inferred when it is loaded. " + + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } + Seq.empty[HiveColumn] } val tableType = if (isExternal) { @@ -163,7 +195,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive specifiedDatabase = Option(dbName), name = tblName, schema = Seq.empty, - partitionColumns = Seq.empty, + partitionColumns = metastorePartitionColumns, tableType = tableType, properties = tableProperties.toMap, serdeProperties = options)) @@ -199,7 +231,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val dataSourceTable = cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) // Then, if alias is specified, wrap the table with a Subquery using the alias. - // Othersie, wrap the table with a Subquery using the table name. + // Otherwise, wrap the table with a Subquery using the table name. val withAlias = alias.map(a => Subquery(a, dataSourceTable)).getOrElse( Subquery(tableIdent.last, dataSourceTable)) @@ -270,7 +302,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val partitionColumnDataTypes = partitionSchema.map(_.dataType) val partitions = metastoreRelation.hiveQlPartitions.map { p => val location = p.getLocation - val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) ParquetPartition(values, location) @@ -485,17 +517,19 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) : LogicalPlan = { val childOutputDataTypes = child.output.map(_.dataType) + val numDynamicPartitions = p.partition.values.count(_.isEmpty) val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) + (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) + .take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { - p + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else if (childOutputDataTypes.size == tableOutputDataTypes.size && childOutputDataTypes.zip(tableOutputDataTypes) .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite, p.ifNotExists) + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else { // Only do the casting when child output data types differ from table output data types. val castedChildOutput = child.output.zip(table.output).map { @@ -513,13 +547,17 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = ??? + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + throw new UnsupportedOperationException + } /** * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + throw new UnsupportedOperationException + } override def unregisterAllTables(): Unit = {} } @@ -530,7 +568,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * because Hive table doesn't have nullability for ARRAY, MAP, STRUCT types. */ private[hive] case class InsertIntoHiveTable( - table: LogicalPlan, + table: MetastoreRelation, partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, @@ -540,7 +578,13 @@ private[hive] case class InsertIntoHiveTable( override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = child.output - override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { + val numDynamicPartitions = partition.values.count(_.isEmpty) + + // This is the expected schema of the table prepared to be inserted into, + // including dynamic partition columns. + val tableOutput = table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions) + + override lazy val resolved: Boolean = childrenResolved && child.output.zip(tableOutput).forall { case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType) } } @@ -553,7 +597,7 @@ private[hive] case class MetastoreRelation self: Product => - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => databaseName == relation.databaseName && tableName == relation.tableName && @@ -627,8 +671,8 @@ private[hive] case class MetastoreRelation @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { - val totalSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstTotalSize) - val rawDataSize = hiveQlTable.getParameters.get(HiveShim.getStatsSetupConstRawDataSize) + val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) + val rawDataSize = hiveQlTable.getParameters.get(StatsSetupConst.RAW_DATA_SIZE) // TODO: check if this estimate is valid for tables after partition pruning. // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be // relatively cheap if parameters for the table are populated into the metastore. An @@ -654,11 +698,7 @@ private[hive] case class MetastoreRelation } } - val tableDesc = HiveShim.getTableDesc( - Class.forName( - hiveQlTable.getSerializationLib, - true, - Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], + val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to @@ -668,25 +708,25 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { + implicit class SchemaAttribute(f: HiveColumn) { def toAttribute: AttributeReference = AttributeReference( - f.getName, - HiveMetastoreTypes.toDataType(f.getType), + f.name, + HiveMetastoreTypes.toDataType(f.hiveType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) } - // Must be a stable value since new attributes are born here. - val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) + /** PartitionKey attributes */ + val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = table.schema.map(_.toAttribute) val output = attributes ++ partitionKeys /** An attribute map that can be used to lookup original attributes based on expression id. */ - val attributeMap = AttributeMap(output.map(o => (o,o))) + val attributeMap = AttributeMap(output.map(o => (o, o))) /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) @@ -700,6 +740,11 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($HiveShim.UNLIMITED_DECIMAL_PRECISION,$HiveShim.UNLIMITED_DECIMAL_SCALE)" + } + def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => @@ -716,7 +761,7 @@ private[hive] object HiveMetastoreTypes { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case d: DecimalType => HiveShim.decimalMetastoreString(d) + case d: DecimalType => decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2cbb5ca4d2e0..2de7a99c122f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive import java.sql.Date -import scala.collection.mutable.ArrayBuffer - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.ql.{ErrorMsg, Context} @@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.sources.DescribeCommand +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ @@ -46,6 +45,7 @@ import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer /** * Used when we need to start parsing the AST before deciding that we are going to pass the command @@ -57,7 +57,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty } -case class CreateTableAsSelect( +private[hive] case class CreateTableAsSelect( tableDesc: HiveTable, child: LogicalPlan, allowExisting: Boolean) extends UnaryNode with Command { @@ -415,13 +415,6 @@ private[hive] object HiveQl { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"_c$i")() - } - } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { @@ -665,7 +658,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C HiveColumn(field.getName, field.getType, field.getComment) }) } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil)=> + case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => @@ -775,7 +768,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder + Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => @@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq) + select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => @@ -1151,7 +1144,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Seq(false, false) => Inner }.toBuffer - val joinedTables = tables.reduceLeft(Join(_,_, Inner, None)) + val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) // Must be transform down. val joinedResult = joinedTables transform { @@ -1171,7 +1164,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // worth the number of hacks that will be required to implement it. Namely, we need to add // some sort of mapped star expansion that would expand all child output row to be similarly // named output expressions where some aggregate expression has been applied (i.e. First). - ??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + throw new UnsupportedOperationException case Token(allJoinTokens(joinToken), relation1 :: @@ -1306,16 +1300,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C HiveParser.DecimalLiteral) /* Case insensitive matches */ - val ARRAY = "(?i)ARRAY".r - val COALESCE = "(?i)COALESCE".r val COUNT = "(?i)COUNT".r - val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r - val MAX = "(?i)MAX".r - val MIN = "(?i)MIN".r - val UPPER = "(?i)UPPER".r - val LOWER = "(?i)LOWER".r - val RAND = "(?i)RAND".r val AND = "(?i)AND".r val OR = "(?i)OR".r val NOT = "(?i)NOT".r @@ -1329,8 +1315,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val BETWEEN = "(?i)BETWEEN".r val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r - val SUBSTR = "(?i)SUBSTR(?:ING)?".r - val SQRT = "(?i)SQRT".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -1352,18 +1336,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedStar(Some(name)) /* Aggregate Functions */ - case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg)) case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg)) - - /* System functions about string operations */ - case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg)) - case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg)) /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => @@ -1413,7 +1388,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg)) /* Comparisons */ case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) @@ -1468,17 +1442,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("[", child :: ordinal :: Nil) => UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - /* Other functions */ - case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) => - CreateArray(children.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand() - case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType)) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => - Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) - case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) - /* Window Functions */ case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) @@ -1560,6 +1523,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C """.stripMargin) } + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { case Token(windowName, Nil) :: Nil => // Refer to a window spec defined in the window clause. @@ -1613,11 +1580,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } else { val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token("preceding", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt) - case Token("following", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt) - case Token("current", Nil) => CurrentRow + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow case _ => throw new NotImplementedError( s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} @@ -1663,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - (HiveGenericUdtf( + (HiveGenericUDTF( new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)), attributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 000000000000..d08c59415165 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +/* Implicit conversions */ +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable + +import org.apache.spark.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + + /* + * This function in hive-0.13 become private, but we have to do this to walkaround hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null && ids.nonEmpty) { + ColumnProjectionUtils.appendReadColumns(conf, ids) + } + if (names != null && names.nonEmpty) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. + * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String) + extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + private var instance: AnyRef = null + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.read(functionInBytes, 0, functionInBytesLength) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d46a127d47d3..452b7f0bcc74 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate -import org.apache.spark.sql.catalyst.expressions.{Row, _} +import org.apache.spark.sql.catalyst.expressions.{InternalRow, _} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -137,10 +137,10 @@ private[hive] trait HiveStrategies { val partitionLocations = partitions.map(_.getLocation) if (partitionLocations.isEmpty) { - PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil } else { hiveContext - .parquetFile(partitionLocations: _*) + .read.parquet(partitionLocations: _*) .addPartitioningAttributes(relation.partitionKeys) .lowerCase .where(unresolvedOtherPredicates) @@ -152,7 +152,7 @@ private[hive] trait HiveStrategies { } else { hiveContext - .parquetFile(relation.hiveQlTable.getDataLocation.toString) + .read.parquet(relation.hiveQlTable.getDataLocation.toString) .lowerCase .where(unresolvedOtherPredicates) .select(unresolvedProjection: _*) @@ -165,7 +165,7 @@ private[hive] trait HiveStrategies { // TODO: Remove this hack for Spark 1.3. case iae: java.lang.IllegalArgumentException if iae.getMessage.contains("Can not create a Path from an empty string") => - PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil } case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 0b6f7a334a71..b251a9523bed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ @@ -25,26 +24,26 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.SerializableWritable -import org.apache.spark.broadcast.Broadcast import org.apache.spark.Logging +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A trait for subclasses that handle table scans. */ private[hive] sealed trait TableReader { - def makeRDDForTable(hiveTable: HiveTable): RDD[Row] + def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] - def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] } @@ -73,13 +72,13 @@ class HadoopTableReader( // TODO: set aws s3 credentials. private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) + sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) - override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = + override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, Class.forName( - relation.tableDesc.getSerdeClassName, true, Utils.getSparkClassLoader) + relation.tableDesc.getSerdeClassName, true, Utils.getContextOrSparkClassLoader) .asInstanceOf[Class[Deserializer]], filterOpt = None) @@ -95,7 +94,7 @@ class HadoopTableReader( def makeRDDForTable( hiveTable: HiveTable, deserializerClass: Class[_ <: Deserializer], - filterOpt: Option[PathFilter]): RDD[Row] = { + filterOpt: Option[PathFilter]): RDD[InternalRow] = { assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") @@ -126,7 +125,7 @@ class HadoopTableReader( deserializedHadoopRDD } - override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = { + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[InternalRow] = { val partitionToDeserializer = partitions.map(part => (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) @@ -145,7 +144,7 @@ class HadoopTableReader( def makeRDDForPartitionedTable( partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], - filterOpt: Option[PathFilter]): RDD[Row] = { + filterOpt: Option[PathFilter]): RDD[InternalRow] = { // SPARK-5068:get FileStatus and do the filtering locally when the path is not exists def verifyPartitionPath( @@ -172,7 +171,7 @@ class HadoopTableReader( path.toString + tails } - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val partNum = Utilities.getPartitionDesc(partition).getPartSpec.size(); var pathPatternStr = getPathPatternByPath(partNum, partPath) if (!pathPatternSet.contains(pathPatternStr)) { @@ -187,7 +186,7 @@ class HadoopTableReader( val hivePartitionRDDs = verifyPartitionPath(partitionToDeserializer) .map { case (partition, partDeserializer) => val partDesc = Utilities.getPartitionDesc(partition) - val partPath = HiveShim.getDataLocationPath(partition) + val partPath = partition.getDataLocation val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) val ifc = partDesc.getInputFileFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] @@ -244,7 +243,7 @@ class HadoopTableReader( // Even if we don't use any partitions, we still need an empty RDD if (hivePartitionRDDs.size == 0) { - new EmptyRDD[Row](sc.sparkContext) + new EmptyRDD[InternalRow](sc.sparkContext) } else { new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) } @@ -277,7 +276,7 @@ class HadoopTableReader( val rdd = new HadoopRDD( sc.sparkContext, - _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), inputFormatClass, classOf[Writable], @@ -320,12 +319,12 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], mutableRow: MutableRow, - tableDeser: Deserializer): Iterator[Row] = { + tableDeser: Deserializer): Iterator[InternalRow] = { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] } else { - HiveShim.getConvertedOI( + ObjectInspectorConverters.getConvertedOI( rawDeser.getObjectInspector, tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] } @@ -358,16 +357,16 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + row.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) + row.setInt(ordinal, DateTimeUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) @@ -392,7 +391,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { i += 1 } - mutableRow: Row + mutableRow: InternalRow } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 99aa0f1ded3f..cbd2bf6b5eed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive.client import java.io.{BufferedReader, InputStreamReader, File, PrintStream} import java.net.URI import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} +import javax.annotation.concurrent.GuardedBy + +import org.apache.spark.util.CircularBuffer import scala.collection.JavaConversions._ import scala.language.reflectiveCalls @@ -27,7 +30,7 @@ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.metastore.api.Database import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.metastore.{TableType => HTableType} import org.apache.hadoop.hive.metastore.api import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata @@ -54,49 +57,43 @@ import org.apache.spark.sql.execution.QueryExecutionException * @param version the version of hive used when pick function calls that are not compatible. * @param config a collection of configuration options that will be added to the hive conf before * opening the hive client. + * @param initClassLoader the classloader used when creating the `state` field of + * this ClientWrapper. */ private[hive] class ClientWrapper( version: HiveVersion, - config: Map[String, String]) + config: Map[String, String], + initClassLoader: ClassLoader) extends ClientInterface - with Logging - with ReflectionMagic { + with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - private val outputBuffer = new java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](10240) - def write(i: Int): Unit = { - buffer(pos) = i - pos = (pos + 1) % buffer.size - } - - override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while(line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() - } - stringBuilder.toString() - } + private val outputBuffer = new CircularBuffer() + + private val shim = version match { + case hive.v12 => new Shim_v0_12() + case hive.v13 => new Shim_v0_13() + case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() } + // Create an internal session state for this ClientWrapper. val state = { val original = Thread.currentThread().getContextClassLoader - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + // Switch to the initClassLoader. + Thread.currentThread().setContextClassLoader(initClassLoader) val ret = try { val oldState = SessionState.get() if (oldState == null) { val initialConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) config.foreach { case (k, v) => logDebug(s"Hive Config: $k=$v") initialConf.set(k, v) @@ -119,23 +116,70 @@ private[hive] class ClientWrapper( def conf: HiveConf = SessionState.get().getConf // TODO: should be a def?s - private val client = Hive.get(conf) + // When we create this val client, the HiveConf of it (conf) is the one associated with state. + @GuardedBy("this") + private var client = Hive.get(conf) + + // We use hive's conf for compatibility. + private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) + private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) + + /** + * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. + */ + private def retryLocked[A](f: => A): A = synchronized { + // Hive sometimes retries internally, so set a deadline to avoid compounding delays. + val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong + var numTries = 0 + var caughtException: Exception = null + do { + numTries += 1 + try { + return f + } catch { + case e: Exception if causedByThrift(e) => + caughtException = e + logWarning( + "HiveClientWrapper got thrift exception, destroying client and retrying " + + s"(${retryLimit - numTries} tries remaining)", e) + Thread.sleep(retryDelayMillis) + try { + client = Hive.get(state.getConf, true) + } catch { + case e: Exception if causedByThrift(e) => + logWarning("Failed to refresh hive client, will retry.", e) + } + } + } while (numTries <= retryLimit && System.nanoTime < deadline) + if (System.nanoTime > deadline) { + logWarning("Deadline exceeded") + } + throw caughtException + } + + private def causedByThrift(e: Throwable): Boolean = { + var target = e + while (target != null) { + val msg = target.getMessage() + if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { + return true + } + target = target.getCause() + } + false + } /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ - private def withHiveState[A](f: => A): A = synchronized { + private def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + // Set the thread local metastore client to the client associated with this ClientWrapper. Hive.set(client) - version match { - case hive.v12 => - classOf[SessionState] - .callStatic[SessionState, SessionState]("start", state) - case hive.v13 => - classOf[SessionState] - .callStatic[SessionState, SessionState]("setCurrentSessionState", state) - } + // setCurrentSessionState will use the classLoader associated + // with the HiveConf in `state` to override the context class loader of the current + // thread. + shim.setCurrentSessionState(state) val ret = try f finally { Thread.currentThread().setContextClassLoader(original) } @@ -193,15 +237,12 @@ private[hive] class ClientWrapper( properties = h.getParameters.toMap, serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap, tableType = h.getTableType match { - case TableType.MANAGED_TABLE => ManagedTable - case TableType.EXTERNAL_TABLE => ExternalTable - case TableType.VIRTUAL_VIEW => VirtualView - case TableType.INDEX_TABLE => IndexTable - }, - location = version match { - case hive.v12 => Option(h.call[URI]("getDataLocation")).map(_.toString) - case hive.v13 => Option(h.call[Path]("getDataLocation")).map(_.toString) + case HTableType.MANAGED_TABLE => ManagedTable + case HTableType.EXTERNAL_TABLE => ExternalTable + case HTableType.VIRTUAL_VIEW => VirtualView + case HTableType.INDEX_TABLE => IndexTable }, + location = shim.getDataLocation(h), inputFormat = Option(h.getInputFormatClass).map(_.getName), outputFormat = Option(h.getOutputFormatClass).map(_.getName), serde = Option(h.getSerializationLib), @@ -231,14 +272,7 @@ private[hive] class ClientWrapper( // set create time qlTable.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) - version match { - case hive.v12 => - table.location.map(new URI(_)).foreach(u => qlTable.call[URI, Unit]("setDataLocation", u)) - case hive.v13 => - table.location - .map(new org.apache.hadoop.fs.Path(_)) - .foreach(qlTable.call[Path, Unit]("setDataLocation", _)) - } + table.location.foreach { loc => shim.setDataLocation(qlTable, loc) } table.inputFormat.map(toInputFormat).foreach(qlTable.setInputFormatClass) table.outputFormat.map(toOutputFormat).foreach(qlTable.setOutputFormatClass) table.serde.foreach(qlTable.setSerializationLib) @@ -279,13 +313,7 @@ private[hive] class ClientWrapper( override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState { val qlTable = toQlTable(hTable) - val qlPartitions = version match { - case hive.v12 => - client.call[metadata.Table, JSet[metadata.Partition]]("getAllPartitionsForPruner", qlTable) - case hive.v13 => - client.call[metadata.Table, JSet[metadata.Partition]]("getAllPartitionsOf", qlTable) - } - qlPartitions.toSeq.map(toHivePartition) + shim.getAllPartitions(client, qlTable).map(toHivePartition) } override def listTables(dbName: String): Seq[String] = withHiveState { @@ -315,15 +343,7 @@ private[hive] class ClientWrapper( val tokens: Array[String] = cmd_trimmed.split("\\s+") // The remainder of the command. val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - val proc: CommandProcessor = version match { - case hive.v12 => - classOf[CommandProcessorFactory] - .callStatic[String, HiveConf, CommandProcessor]("get", tokens(0), conf) - case hive.v13 => - classOf[CommandProcessorFactory] - .callStatic[Array[String], HiveConf, CommandProcessor]("get", Array(tokens(0)), conf) - } - + val proc = shim.getCommandProcessor(tokens(0), conf) proc match { case driver: Driver => val response: CommandProcessorResponse = driver.run(cmd) @@ -334,21 +354,7 @@ private[hive] class ClientWrapper( } driver.setMaxRows(maxRows) - val results = version match { - case hive.v12 => - val res = new JArrayList[String] - driver.call[JArrayList[String], Boolean]("getResults", res) - res.toSeq - case hive.v13 => - val res = new JArrayList[Object] - driver.call[JList[Object], Boolean]("getResults", res) - res.map { r => - r match { - case s: String => s - case a: Array[Object] => a(0).asInstanceOf[String] - } - } - } + val results = shim.getDriverResults(driver) driver.close() results @@ -382,8 +388,8 @@ private[hive] class ClientWrapper( holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean): Unit = withHiveState { - - client.loadPartition( + shim.loadPartition( + client, new Path(loadPath), // TODO: Use URI tableName, partSpec, @@ -398,7 +404,8 @@ private[hive] class ClientWrapper( tableName: String, replace: Boolean, holdDDLTime: Boolean): Unit = withHiveState { - client.loadTable( + shim.loadTable( + client, new Path(loadPath), tableName, replace, @@ -413,7 +420,8 @@ private[hive] class ClientWrapper( numDP: Int, holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit = withHiveState { - client.loadDynamicPartitions( + shim.loadDynamicPartitions( + client, new Path(loadPath), tableName, partSpec, @@ -428,7 +436,7 @@ private[hive] class ClientWrapper( logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).foreach { index => - client.dropIndex("default", t, index.getIndexName, true) + shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala new file mode 100644 index 000000000000..1fa9d278e2a5 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.client + +import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} +import java.lang.reflect.{Method, Modifier} +import java.net.URI +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} +import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.session.SessionState + +/** + * A shim that defines the interface between ClientWrapper and the underlying Hive library used to + * talk to the metastore. Each Hive version has its own implementation of this class, defining + * version-specific version of needed functions. + * + * The guideline for writing shims is: + * - always extend from the previous version unless really not possible + * - initialize methods in lazy vals, both for quicker access for multiple invocations, and to + * avoid runtime errors due to the above guideline. + */ +private[client] sealed abstract class Shim { + + /** + * Set the current SessionState to the given SessionState. Also, set the context classloader of + * the current thread to the one set in the HiveConf of this given `state`. + * @param state + */ + def setCurrentSessionState(state: SessionState): Unit + + /** + * This shim is necessary because the return type is different on different versions of Hive. + * All parameters are the same, though. + */ + def getDataLocation(table: Table): Option[String] + + def setDataLocation(table: Table, loc: String): Unit + + def getAllPartitions(hive: Hive, table: Table): Seq[Partition] + + def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor + + def getDriverResults(driver: Driver): Seq[String] + + def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + + def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit + + def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit + + def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit + + def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit + + protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = { + val method = findMethod(klass, name, args: _*) + require(Modifier.isStatic(method.getModifiers()), + s"Method $name of class $klass is not static.") + method + } + + protected def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = { + klass.getMethod(name, args: _*) + } + +} + +private[client] class Shim_v0_12 extends Shim { + + private lazy val startMethod = + findStaticMethod( + classOf[SessionState], + "start", + classOf[SessionState]) + private lazy val getDataLocationMethod = findMethod(classOf[Table], "getDataLocation") + private lazy val setDataLocationMethod = + findMethod( + classOf[Table], + "setDataLocation", + classOf[URI]) + private lazy val getAllPartitionsMethod = + findMethod( + classOf[Hive], + "getAllPartitionsForPruner", + classOf[Table]) + private lazy val getCommandProcessorMethod = + findStaticMethod( + classOf[CommandProcessorFactory], + "get", + classOf[String], + classOf[HiveConf]) + private lazy val getDriverResultsMethod = + findMethod( + classOf[Driver], + "getResults", + classOf[JArrayList[String]]) + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE) + + override def setCurrentSessionState(state: SessionState): Unit = { + // Starting from Hive 0.13, setCurrentSessionState will internally override + // the context class loader of the current thread by the class loader set in + // the conf of the SessionState. So, for this Hive 0.12 shim, we add the same + // behavior and make shim.setCurrentSessionState of all Hive versions have the + // consistent behavior. + Thread.currentThread().setContextClassLoader(state.getConf.getClassLoader) + startMethod.invoke(null, state) + } + + override def getDataLocation(table: Table): Option[String] = + Option(getDataLocationMethod.invoke(table)).map(_.toString()) + + override def setDataLocation(table: Table, loc: String): Unit = + setDataLocationMethod.invoke(table, new URI(loc)) + + override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = + getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] + + override def getDriverResults(driver: Driver): Seq[String] = { + val res = new JArrayList[String]() + getDriverResultsMethod.invoke(driver, res) + res.toSeq + } + + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + } + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) + } + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + } + +} + +private[client] class Shim_v0_13 extends Shim_v0_12 { + + private lazy val setCurrentSessionStateMethod = + findStaticMethod( + classOf[SessionState], + "setCurrentSessionState", + classOf[SessionState]) + private lazy val setDataLocationMethod = + findMethod( + classOf[Table], + "setDataLocation", + classOf[Path]) + private lazy val getAllPartitionsMethod = + findMethod( + classOf[Hive], + "getAllPartitionsOf", + classOf[Table]) + private lazy val getCommandProcessorMethod = + findStaticMethod( + classOf[CommandProcessorFactory], + "get", + classOf[Array[String]], + classOf[HiveConf]) + private lazy val getDriverResultsMethod = + findMethod( + classOf[Driver], + "getResults", + classOf[JList[Object]]) + + override def setCurrentSessionState(state: SessionState): Unit = + setCurrentSessionStateMethod.invoke(null, state) + + override def setDataLocation(table: Table, loc: String): Unit = + setDataLocationMethod.invoke(table, new Path(loc)) + + override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + + override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = + getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] + + override def getDriverResults(driver: Driver): Seq[String] = { + val res = new JArrayList[Object]() + getDriverResultsMethod.invoke(driver, res) + res.map { r => + r match { + case s: String => s + case a: Array[Object] => a(0).asInstanceOf[String] + } + } + } + +} + +private[client] class Shim_v0_14 extends Shim_v0_13 { + + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val getTimeVarMethod = + findMethod( + classOf[HiveConf], + "getTimeVar", + classOf[HiveConf.ConfVars], + classOf[TimeUnit]) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + holdDDLTime: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + JBoolean.TRUE, JBoolean.FALSE) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + holdDDLTime: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, + JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) + } + + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + getTimeVarMethod.invoke( + conf, + HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, + TimeUnit.MILLISECONDS).asInstanceOf[Long] + } +} + +private[client] class Shim_v1_0 extends Shim_v0_14 { + +} + +private[client] class Shim_v1_1 extends Shim_v1_0 { + + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + } + +} + +private[client] class Shim_v1_2 extends Shim_v1_1 { + + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + 0: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7f94c93ba49c..3d609a66f366 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.File +import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util @@ -28,6 +29,7 @@ import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitUtils +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext @@ -39,39 +41,41 @@ private[hive] object IsolatedClientLoader { */ def forVersion( version: String, - config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized { + config: Map[String, String] = Map.empty, + ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion)) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, + downloadVersion(resolvedVersion, ivyPath)) new IsolatedClientLoader(hiveVersion(version), files, config) } def hiveVersion(version: String): HiveVersion = version match { case "12" | "0.12" | "0.12.0" => hive.v12 case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 + case "14" | "0.14" | "0.14.0" => hive.v14 + case "1.0" | "1.0.0" => hive.v1_0 + case "1.1" | "1.1.0" => hive.v1_1 + case "1.2" | "1.2.0" => hive.v1_2 } - private def downloadVersion(version: HiveVersion): Seq[URL] = { - val hiveArtifacts = - (Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") ++ - (if (version.hasBuiltinsJar) "hive-builtins" :: Nil else Nil)) - .map(a => s"org.apache.hive:$a:${version.fullVersion}") :+ - "com.google.guava:guava:14.0.1" :+ - "org.apache.hadoop:hadoop-client:2.4.0" :+ - "mysql:mysql-connector-java:5.1.12" + private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + val hiveArtifacts = version.extraDeps ++ + Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") + .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ + Seq("com.google.guava:guava:14.0.1", + "org.apache.hadoop:hadoop-client:2.4.0") val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( hiveArtifacts.mkString(","), Some("http://www.datanucleus.org/downloads/maven2"), - None) + ivyPath, + exclusions = version.exclusions) } val allFiles = classpath.split(",").map(new File(_)).toSet // TODO: Remove copy logic. - val tempDir = File.createTempFile("hive", "v" + version.toString) - tempDir.delete() - tempDir.mkdir() - + val tempDir = Utils.createTempDir(namePrefix = s"hive-${version}") allFiles.foreach(f => FileUtils.copyFileToDirectory(f, tempDir)) tempDir.listFiles().map(_.toURL) } @@ -91,14 +95,13 @@ private[hive] object IsolatedClientLoader { * `ClientInterface`, unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures - * that are not compatibile accross versions. + * that are not compatible across versions. * @param execJars A collection of jar files that must include hive and hadoop. * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. Must not know about hive classes. + * @param rootClassLoader The system root classloader. Must not know about Hive classes. * @param baseClassLoader The spark classloader that is used to load shared classes. - * */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -106,11 +109,13 @@ private[hive] class IsolatedClientLoader( val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, - val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader) + val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, + val sharedPrefixes: Seq[String] = Seq.empty, + val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { // Check to make sure that the root classloader does not know about Hive. - assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) + assert(Try(rootClassLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray @@ -122,13 +127,14 @@ private[hive] class IsolatedClientLoader( name.startsWith("scala.") || name.startsWith("com.google") || name.startsWith("java.lang.") || - name.startsWith("java.net") + name.startsWith("java.net") || + sharedPrefixes.exists(name.startsWith) /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = - name.startsWith("org.apache.spark.sql.hive.execution.PairSerDe") || name.startsWith(classOf[ClientWrapper].getName) || - name.startsWith(classOf[ReflectionMagic].getName) + name.startsWith(classOf[Shim].getName) || + barrierPrefixes.exists(name.startsWith) protected def classToPath(name: String): String = name.replaceAll("\\.", "/") + ".class" @@ -143,6 +149,7 @@ private[hive] class IsolatedClientLoader( def doLoadClass(name: String, resolve: Boolean): Class[_] = { val classFileName = name.replaceAll("\\.", "/") + ".class" if (isBarrierClass(name) && isolationOn) { + // For barrier classes, we construct a new copy of the class. val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") defineClass(name, bytes, 0, bytes.length) @@ -150,6 +157,7 @@ private[hive] class IsolatedClientLoader( logDebug(s"hive class: $name - ${getResource(classToPath(name))}") super.loadClass(name, resolve) } else { + // For shared classes, we delegate to baseClassLoader. logDebug(s"shared class: $name") baseClassLoader.loadClass(name) } @@ -165,14 +173,19 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[ClientWrapper].getName) .getConstructors.head - .newInstance(version, config) + .newInstance(version, config, classLoader) .asInstanceOf[ClientInterface] } catch { - case ReflectionException(cnf: NoClassDefFoundError) => - throw new ClassNotFoundException( - s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + - "Please make sure that jars for your version of hive and hadoop are included in the " + - s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") + case e: InvocationTargetException => + if (e.getCause().isInstanceOf[NoClassDefFoundError]) { + val cnf = e.getCause().asInstanceOf[NoClassDefFoundError] + throw new ClassNotFoundException( + s"$cnf when creating Hive client using classpath: ${execJars.mkString(", ")}\n" + + "Please make sure that jars for your version of hive and hadoop are included in the " + + s"paths passed to ${HiveContext.HIVE_METASTORE_JARS}.") + } else { + throw e + } } finally { Thread.currentThread.setContextClassLoader(baseClassLoader) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala deleted file mode 100644 index c600b158c546..000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.client - -import scala.reflect._ - -/** Unwraps reflection exceptions. */ -private[client] object ReflectionException { - def unapply(a: Throwable): Option[Throwable] = a match { - case ite: java.lang.reflect.InvocationTargetException => Option(ite.getCause) - case _ => None - } -} - -/** - * Provides implicit functions on any object for calling methods reflectively. - */ -protected trait ReflectionMagic { - /** code for InstanceMagic - println( - (1 to 22).map { n => - def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ") - val types = repeat(n => s"A$n <: AnyRef : ClassTag") - val inArgs = repeat(n => s"a$n: A$n") - val erasure = repeat(n => s"classTag[A$n].erasure") - val outArgs = repeat(n => s"a$n") - s"""|def call[$types, R](name: String, $inArgs): R = { - | clazz.getMethod(name, $erasure).invoke(a, $outArgs).asInstanceOf[R] - |}""".stripMargin - }.mkString("\n") - ) - */ - - // scalastyle:off - protected implicit class InstanceMagic(a: Any) { - private val clazz = a.getClass - - def call[R](name: String): R = { - clazz.getMethod(name).invoke(a).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = { - clazz.getMethod(name, classTag[A1].erasure).invoke(a, a1).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(a, a1, a2).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(a, a1, a2, a3).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(a, a1, a2, a3, a4).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(a, a1, a2, a3, a4, a5).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(a, a1, a2, a3, a4, a5, a6).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R] - } - def call[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = { - clazz.getMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(a, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R] - } - } - - /** code for StaticMagic - println( - (1 to 22).map { n => - def repeat(str: String => String) = (1 to n).map(i => str(i.toString)).mkString(", ") - val types = repeat(n => s"A$n <: AnyRef : ClassTag") - val inArgs = repeat(n => s"a$n: A$n") - val erasure = repeat(n => s"classTag[A$n].erasure") - val outArgs = repeat(n => s"a$n") - s"""|def callStatic[$types, R](name: String, $inArgs): R = { - | c.getDeclaredMethod(name, $erasure).invoke(c, $outArgs).asInstanceOf[R] - |}""".stripMargin - }.mkString("\n") - ) - */ - - protected implicit class StaticMagic(c: Class[_]) { - def callStatic[A1 <: AnyRef : ClassTag, R](name: String, a1: A1): R = { - c.getDeclaredMethod(name, classTag[A1].erasure).invoke(c, a1).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure).invoke(c, a1, a2).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure).invoke(c, a1, a2, a3).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure).invoke(c, a1, a2, a3, a4).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure).invoke(c, a1, a2, a3, a4, a5).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure).invoke(c, a1, a2, a3, a4, a5, a6).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21).asInstanceOf[R] - } - def callStatic[A1 <: AnyRef : ClassTag, A2 <: AnyRef : ClassTag, A3 <: AnyRef : ClassTag, A4 <: AnyRef : ClassTag, A5 <: AnyRef : ClassTag, A6 <: AnyRef : ClassTag, A7 <: AnyRef : ClassTag, A8 <: AnyRef : ClassTag, A9 <: AnyRef : ClassTag, A10 <: AnyRef : ClassTag, A11 <: AnyRef : ClassTag, A12 <: AnyRef : ClassTag, A13 <: AnyRef : ClassTag, A14 <: AnyRef : ClassTag, A15 <: AnyRef : ClassTag, A16 <: AnyRef : ClassTag, A17 <: AnyRef : ClassTag, A18 <: AnyRef : ClassTag, A19 <: AnyRef : ClassTag, A20 <: AnyRef : ClassTag, A21 <: AnyRef : ClassTag, A22 <: AnyRef : ClassTag, R](name: String, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12, a13: A13, a14: A14, a15: A15, a16: A16, a17: A17, a18: A18, a19: A19, a20: A20, a21: A21, a22: A22): R = { - c.getDeclaredMethod(name, classTag[A1].erasure, classTag[A2].erasure, classTag[A3].erasure, classTag[A4].erasure, classTag[A5].erasure, classTag[A6].erasure, classTag[A7].erasure, classTag[A8].erasure, classTag[A9].erasure, classTag[A10].erasure, classTag[A11].erasure, classTag[A12].erasure, classTag[A13].erasure, classTag[A14].erasure, classTag[A15].erasure, classTag[A16].erasure, classTag[A17].erasure, classTag[A18].erasure, classTag[A19].erasure, classTag[A20].erasure, classTag[A21].erasure, classTag[A22].erasure).invoke(c, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22).asInstanceOf[R] - } - } - // scalastyle:on -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 7db9200d4744..b48082fe4b36 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -19,15 +19,50 @@ package org.apache.spark.sql.hive /** Support for interacting with different versions of the HiveMetastoreClient */ package object client { - private[client] abstract class HiveVersion(val fullVersion: String, val hasBuiltinsJar: Boolean) + private[client] abstract class HiveVersion( + val fullVersion: String, + val extraDeps: Seq[String] = Nil, + val exclusions: Seq[String] = Nil) // scalastyle:off private[client] object hive { - case object v10 extends HiveVersion("0.10.0", true) - case object v11 extends HiveVersion("0.11.0", false) - case object v12 extends HiveVersion("0.12.0", false) - case object v13 extends HiveVersion("0.13.1", false) + case object v12 extends HiveVersion("0.12.0") + case object v13 extends HiveVersion("0.13.1") + + // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT which does not exist in + // maven central anymore, so override those with a version that exists. + // + // The other excluded dependencies are also nowhere to be found, so exclude them explicitly. If + // they're needed by the metastore client, users will have to dig them out of somewhere and use + // configuration to point Spark at the correct jars. + case object v14 extends HiveVersion("0.14.0", + extraDeps = Seq("org.apache.calcite:calcite-core:1.3.0-incubating", + "org.apache.calcite:calcite-avatica:1.3.0-incubating"), + exclusions = Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v1_0 extends HiveVersion("1.0.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + // The curator dependency was added to the exclusions here because it seems to confuse the ivy + // library. org.apache.curator:curator is a pom dependency but ivy tries to find the jar for it, + // and fails. + case object v1_1 extends HiveVersion("1.1.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + case object v1_2 extends HiveVersion("1.2.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) } // scalastyle:on - -} \ No newline at end of file + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 7d3ec12c4eb0..84358cb73c9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes} +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} /** * Create table and insert the query result into it. @@ -45,22 +43,30 @@ case class CreateTableAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { - import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextInputFormat - val withSchema = + val withFormat = tableDesc.copy( - schema = - query.output.map(c => - HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null)), inputFormat = tableDesc.inputFormat.orElse(Some(classOf[TextInputFormat].getName)), outputFormat = tableDesc.outputFormat .orElse(Some(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]].getName)), serde = tableDesc.serde.orElse(Some(classOf[LazySimpleSerDe].getName()))) + + val withSchema = if (withFormat.schema.isEmpty) { + // Hive doesn't support specifying the column list for target table in CTAS + // However we don't think SparkSQL should follow that. + tableDesc.copy(schema = + query.output.map(c => + HiveColumn(c.name, HiveMetastoreTypes.toMetastoreType(c.dataType), null))) + } else { + withFormat + } + hiveContext.catalog.client.createTable(withSchema) // Get the Metastore Relation diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 6fce69b58b85..5f0ed5393d19 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -21,12 +21,10 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} -import org.apache.spark.sql.execution.{SparkPlan, RunnableCommand} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} -import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.MetastoreRelation +import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 60a9bb630d0d..41b645b2c9c9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -1,34 +1,34 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.types.StringType - -private[hive] -case class HiveNativeCommand(sql: String) extends RunnableCommand { - - override def output: Seq[AttributeReference] = - Seq(AttributeReference("result", StringType, nullable = false)()) - - override def run(sqlContext: SQLContext): Seq[Row] = - sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{Row, SQLContext} + +private[hive] +case class HiveNativeCommand(sql: String) extends RunnableCommand { + + override def output: Seq[AttributeReference] = + Seq(AttributeReference("result", StringType, nullable = false)()) + + override def run(sqlContext: SQLContext): Seq[Row] = + sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 62dc4167b78d..f4c8c9a7e8a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -63,7 +63,7 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } - // Create a local copy of hiveconf,so that scan specific modifications should not impact + // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient private[this] val hiveExtraConf = new HiveConf(context.hiveconf) @@ -72,7 +72,7 @@ case class HiveTableScan( addColumnMetadataToConf(hiveExtraConf) @transient - private[this] val hadoopReader = + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { @@ -123,13 +123,13 @@ case class HiveTableScan( // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. - val row = new GenericRow(castedValues.toArray) + val row = InternalRow.fromSeq(castedValues) shouldKeep.eval(row).asInstanceOf[Boolean] } } } - protected override def doExecute(): RDD[Row] = if (!relation.hiveQlTable.isPartitioned) { + protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c0b0b104e914..05f425f2b65f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,26 +19,26 @@ package org.apache.spark.sql.hive.execution import java.util -import scala.collection.JavaConversions._ - import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{SparkException, TaskContext} + +import scala.collection.JavaConversions._ +import org.apache.spark.util.SerializableJobConf private[hive] case class InsertIntoHiveTable( @@ -62,10 +62,10 @@ case class InsertIntoHiveTable( def output: Seq[Attribute] = child.output def saveAsHiveFile( - rdd: RDD[Row], + rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: SerializableWritable[JobConf], + conf: SerializableJobConf, writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -84,7 +84,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { + def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -106,7 +106,7 @@ case class InsertIntoHiveTable( } writerContainer - .getLocalFileWriter(row) + .getLocalFileWriter(row, table.schema) .write(serializer.serialize(outputData, standardOI)) } @@ -121,12 +121,12 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - protected[sql] lazy val sideEffectResult: Seq[Row] = { + protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = HiveShim.getExternalTmpPath(hiveContext, tableLocation) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) @@ -173,7 +173,7 @@ case class InsertIntoHiveTable( } val jobConf = new JobConf(sc.hiveconf) - val jobConfSer = new SerializableWritable(jobConf) + val jobConfSer = new SerializableJobConf(jobConf) val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) @@ -194,12 +194,10 @@ case class InsertIntoHiveTable( if (partition.nonEmpty) { // loadPartition call orders directories created on the iteration order of the this map - val orderedPartitionSpec = new util.LinkedHashMap[String,String]() - table.hiveQlTable.getPartCols().foreach{ - entry=> - orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + val orderedPartitionSpec = new util.LinkedHashMap[String, String]() + table.hiveQlTable.getPartCols().foreach { entry => + orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } - val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. @@ -253,12 +251,13 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - Seq.empty[Row] + Seq.empty[InternalRow] } - override def executeCollect(): Array[Row] = sideEffectResult.toArray + override def executeCollect(): Array[Row] = + sideEffectResult.toArray - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { sqlContext.sparkContext.parallelize(sideEffectResult, 1) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index bfd26e0170c7..b967e191c585 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} +import java.lang.ProcessBuilder.Redirect import java.util.Properties import scala.collection.JavaConversions._ @@ -34,7 +35,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} /** * Transforms the input by forking and running the specified script. @@ -54,19 +55,23 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil - protected override def doExecute(): RDD[Row] = { + protected override def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) + // We need to start threads connected to the process pipeline: + // 1) The error msg generated by the script process would be hidden. + // 2) If the error msg is too big to chock up the buffer, the input logic would be hung val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream val reader = new BufferedReader(new InputStreamReader(inputStream)) - + val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) - val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { - var cacheRow: Row = null + val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + var cacheRow: InternalRow = null var curLine: String = null var eof: Boolean = false @@ -83,7 +88,7 @@ case class ScriptTransformation( } } - def deserialize(): Row = { + def deserialize(): InternalRow = { if (cacheRow != null) return cacheRow val mutableRow = new SpecificMutableRow(output.map(_.dataType)) @@ -95,7 +100,7 @@ case class ScriptTransformation( val raw = outputSerde.deserialize(writable) val dataList = outputSoi.getStructFieldsDataAsList(raw) val fieldList = outputSoi.getAllStructFieldRefs() - + var i = 0 dataList.foreach( element => { if (element == null) { @@ -113,20 +118,20 @@ case class ScriptTransformation( } } - override def next(): Row = { + override def next(): InternalRow = { if (!hasNext) { throw new NoSuchElementException } - + if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() if (!ioschema.schemaLess) { - new GenericRow(CatalystTypeConverters.convertToCatalyst( + new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow(CatalystTypeConverters.convertToCatalyst( + new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } @@ -145,28 +150,43 @@ case class ScriptTransformation( val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) + // TODO make the 2048 configurable? + val stderrBuffer = new CircularBuffer(2048) + // Consume the error stream from the pipeline, otherwise it will be blocked if + // the pipeline is full. + new RedirectThread(errorStream, // input stream from the pipeline + stderrBuffer, // output to a circular buffer + "Thread-ScriptTransformation-STDERR-Consumer").start() + // Put the write(output to the pipeline) into a single thread // and keep the collector as remain in the main thread. // otherwise it will causes deadlock if the data size greater than // the pipeline / buffer capacity. new Thread(new Runnable() { override def run(): Unit = { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + Utils.tryWithSafeFinally { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + } { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer } } - outputStream.close() } - }).start() + }, "Thread-ScriptTransformation-Feed").start() iterator } @@ -192,7 +212,7 @@ case class HiveScriptIOSchema ( val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - + def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) @@ -206,22 +226,22 @@ case class HiveScriptIOSchema ( } def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name case _ => null } - + val columnTypes = attrs.map { case aref: AttributeReference => aref.dataType case e: NamedExpression => e.dataType - case _ => null + case _ => null } (columns, columnTypes) } - + def initSerDe(serdeClassName: String, columns: Seq[String], columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { @@ -240,7 +260,7 @@ case class HiveScriptIOSchema ( (kv._1.split("'")(1), kv._2.split("'")(1)) }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - + val properties = new Properties() properties.putAll(propsMap) serde.initialize(null, properties) @@ -261,7 +281,7 @@ case class HiveScriptIOSchema ( null } } - + def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { if (outputSerde != null) { outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] @@ -270,3 +290,4 @@ case class HiveScriptIOSchema ( } } } + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 660976334375..71fa3e9c33ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -91,9 +90,15 @@ case class AddJar(path: String) extends RunnableCommand { val jarURL = new java.io.File(path).toURL val newClassLoader = new java.net.URLClassLoader(Array(jarURL), currentClassLoader) Thread.currentThread.setContextClassLoader(newClassLoader) - org.apache.hadoop.hive.ql.metadata.Hive.get().getConf().setClassLoader(newClassLoader) - - // Add jar to isolated hive classloader + // We need to explicitly set the class loader associated with the conf in executionHive's + // state because this class loader will be used as the context class loader of the current + // thread to execute any Hive command. + // We cannot use `org.apache.hadoop.hive.ql.metadata.Hive.get().getConf()` because Hive.get() + // returns the value of a thread local variable and its HiveConf may not be the HiveConf + // associated with `executionHive.state` (for example, HiveContext is created in one thread + // and then add jar is called from another thread). + hiveContext.executionHive.state.getConf.setClassLoader(newClassLoader) + // Add jar to isolated hive (metadataHive) class loader. hiveContext.runSqlHive(s"ADD JAR $path") // Add jar to executors @@ -146,6 +151,7 @@ case class CreateMetastoreDataSource( hiveContext.catalog.createDataSourceTable( tableName, userSpecifiedSchema, + Array.empty[String], provider, optionsWithPath, isExternal) @@ -229,7 +235,7 @@ case class CreateMetastoreDataSourceAsSelect( val data = DataFrame(hiveContext, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. - case Some(schema) => sqlContext.createDataFrame(data.queryExecution.toRdd, schema) + case Some(schema) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, schema) case None => data } @@ -244,6 +250,7 @@ case class CreateMetastoreDataSourceAsSelect( hiveContext.catalog.createDataSourceTable( tableName, Some(resolved.relation.schema), + partitionColumns, provider, optionsWithPath, isExternal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala similarity index 85% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index bc6b3a2d58c3..4dea561ae5f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper -import org.apache.spark.sql.AnalysisException - import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ +import scala.util.Try import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions @@ -30,54 +28,64 @@ import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ -/* Implicit conversions */ -import scala.collection.JavaConversions._ -private[hive] abstract class HiveFunctionRegistry +private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - def lookupFunction(name: String, children: Seq[Expression]): Expression = { - // We only look it up to see if it exists, but do not include it in the HiveUDF since it is - // not always serializable. - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $name")) - - val functionClassName = functionInfo.getFunctionClass.getName - - if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) - } else { - sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + Try(underlying.lookupFunction(name, children)).getOrElse { + // We only look it up to see if it exists, but do not include it in the HiveUDF since it is + // not always serializable. + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + throw new AnalysisException(s"undefined function $name")) + + val functionClassName = functionInfo.getFunctionClass.getName + + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + } else { + sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + } } } + + override def registerFunction(name: String, builder: FunctionBuilder): Unit = + throw new UnsupportedOperationException } -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { - type EvaluatedType = Any + type UDFType = UDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient @@ -113,7 +121,7 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) // TODO: Finish input output types. - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { unwrap( FunctionRegistry.invoke(method, function, conversionHelper .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), @@ -136,10 +144,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) override def get(): AnyRef = wrap(func(), oi) } -private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF - type EvaluatedType = Any + + override def deterministic: Boolean = isUDFDeterministic override def nullable: Boolean = true @@ -169,7 +178,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr lazy val dataType: DataType = inspectorToDataType(returnInspector) - override def eval(input: Row): Any = { + override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. var i = 0 @@ -316,7 +325,7 @@ private[hive] case class HiveWindowFunction( // The object inspector of values returned from the Hive window function. @transient - protected lazy val returnInspector = { + protected lazy val returnInspector = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -336,9 +345,7 @@ private[hive] case class HiveWindowFunction( def nullable: Boolean = true - override type EvaluatedType = Any - - override def eval(input: Row): Any = + override def eval(input: InternalRow): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") @transient @@ -362,7 +369,7 @@ private[hive] case class HiveWindowFunction( evaluator.reset(hiveEvaluatorBuffer) } - override def prepareInputParameters(input: Row): AnyRef = { + override def prepareInputParameters(input: InternalRow): AnyRef = { wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length)) } // Add input parameters for a single row. @@ -402,7 +409,7 @@ private[hive] case class HiveWindowFunction( new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } -private[hive] case class HiveGenericUdaf( +private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -413,7 +420,7 @@ private[hive] case class HiveGenericUdaf( protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) @@ -430,11 +437,11 @@ private[hive] case class HiveGenericUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( +private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -446,7 +453,7 @@ private[hive] case class HiveUdaf( new GenericUDAFBridge(funcWrapper.createFunction()) @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) @@ -463,7 +470,7 @@ private[hive] case class HiveUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) } /** @@ -477,7 +484,7 @@ private[hive] case class HiveUdaf( * Operators that require maintaining state in between input rows should instead be implemented as * user defined aggregations, which have clean semantics even in a partitioned execution. */ -private[hive] case class HiveGenericUdtf( +private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors { @@ -505,7 +512,7 @@ private[hive] case class HiveGenericUdtf( field => (inspectorToDataType(field.getFieldObjectInspector), true) } - override def eval(input: Row): TraversableOnce[Row] = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) @@ -515,23 +522,23 @@ private[hive] case class HiveGenericUdtf( } protected class UDTFCollector extends Collector { - var collected = new ArrayBuffer[Row] + var collected = new ArrayBuffer[InternalRow] override def collect(input: java.lang.Object) { // We need to clone the input here because implementations of // GenericUDTF reuse the same object. Luckily they are always an array, so // it is easy to clone. - collected += unwrap(input, outputInspector).asInstanceOf[Row] + collected += unwrap(input, outputInspector).asInstanceOf[InternalRow] } - def collectRows(): Seq[Row] = { + def collectRows(): Seq[InternalRow] = { val toCollect = collected - collected = new ArrayBuffer[Row] + collected = new ArrayBuffer[InternalRow] toCollect } } - override def terminate(): TraversableOnce[Row] = { + override def terminate(): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. function.close() collector.collectRows() @@ -542,7 +549,7 @@ private[hive] case class HiveGenericUdtf( } } -private[hive] case class HiveUdafFunction( +private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], base: AggregateExpression, @@ -558,12 +565,12 @@ private[hive] case class HiveUdafFunction( } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - + private val inspectors = exprs.map(toInspector).toArray - - private val function = { + + private val function = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) + resolver.getEvaluator(parameterInfo) } private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @@ -571,15 +578,15 @@ private[hive] case class HiveUdafFunction( private val buffer = function.getNewAggregationBuffer - override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) + override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector) @transient val inputProjection = new InterpretedProjection(exprs) @transient protected lazy val cached = new Array[AnyRef](exprs.length) - - def update(input: Row): Unit = { + + def update(input: InternalRow): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index cbc381cc81b5..ecc78a5f8d32 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,8 +34,10 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} -import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hive OutputFormat. @@ -56,7 +58,7 @@ private[hive] class SparkHiveWriterContainer( PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } - protected val conf = new SerializableWritable(jobConf) + protected val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 @@ -69,7 +71,7 @@ private[hive] class SparkHiveWriterContainer( @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) @transient private lazy val outputFormat = - conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] def driverSideSetup() { setIDs(0, 0, 0) @@ -92,7 +94,7 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer def close() { // Seems the boolean value passed into close does not matter. @@ -195,11 +197,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + def convertToHiveRawString(col: String, value: Any): String = { + val raw = String.valueOf(value) + schema(col).dataType match { + case DateType => DateTimeUtils.dateToString(raw.toInt) + case _: DecimalType => BigDecimal(raw).toString() + case _ => raw + } + } + val dynamicPartPath = dynamicPartColNames .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => - val string = if (rawVal == null) null else String.valueOf(rawVal) + val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) val colString = if (string == null || string.isEmpty) { defaultPartName @@ -217,12 +228,11 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) - val path = { - val outputPath = FileOutputFormat.getOutputPath(conf.value) - assert(outputPath != null, "Undefined job output-path") - val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) - new Path(workPath, getOutputName) - } + // use the path like ${hive_tmp}/_temporary/${attemptId}/ + // to avoid write to the same file when `spark.speculation=true` + val path = FileOutputFormat.getTaskOutputPath( + conf.value, + dynamicPartPath.stripPrefix("/") + "/" + getOutputName) HiveFileFormatUtils.getHiveRecordWriter( conf.value, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala new file mode 100644 index 000000000000..0f9a1a6ef3b2 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.sql.types.StructType + +private[orc] object OrcFileOperator extends Logging { + /** + * Retrieves a ORC file reader from a given path. The path can point to either a directory or a + * single ORC file. If it points to an directory, it picks any non-empty ORC file within that + * directory. + * + * The reader returned by this method is mainly used for two purposes: + * + * 1. Retrieving file metadata (schema and compression codecs, etc.) + * 2. Read the actual file content (in this case, the given path should point to the target file) + * + * @note As recorded by SPARK-8501, ORC writes an empty schema (struct<> + logInfo( + s"ORC file $path has empty schema, it probably contains no rows. " + + "Trying to read another ORC file to figure out the schema.") + false + case _ => true + } + } + + val conf = config.getOrElse(new Configuration) + val fs = { + val hdfsPath = new Path(basePath) + hdfsPath.getFileSystem(conf) + } + + listOrcFiles(basePath, conf).iterator.map { path => + path -> OrcFile.createReader(fs, path) + }.collectFirst { + case (path, reader) if isWithNonEmptySchema(path, reader) => reader + } + } + + def readSchema(path: String, conf: Option[Configuration]): StructType = { + val reader = getFileReader(path, conf).getOrElse { + throw new AnalysisException( + s"Failed to discover schema from ORC files stored in $path. " + + "Probably there are either no ORC files or only empty ORC files.") + } + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $path, got Hive schema string: $schema") + HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] + } + + def getObjectInspector( + path: String, conf: Option[Configuration]): Option[StructObjectInspector] = { + getFileReader(path, conf).map(_.getObjectInspector.asInstanceOf[StructObjectInspector]) + } + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDir) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + + if (paths == null || paths.isEmpty) { + throw new IllegalArgumentException( + s"orcFileOperator: path $path does not have valid orc files matching the pattern") + } + + paths + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala new file mode 100644 index 000000000000..250e73a4dba9 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.hadoop.hive.serde2.io.DateWritable + +import org.apache.spark.Logging +import org.apache.spark.sql.sources._ + +/** + * It may be optimized by push down partial filters. But we are conservative here. + * Because if some filters fail to be parsed, the tree may be corrupted, + * and cannot be used anymore. + */ +private[orc] object OrcFilters extends Logging { + def createFilter(expr: Array[Filter]): Option[SearchArgument] = { + expr.reduceOption(And).flatMap { conjunction => + val builder = SearchArgument.FACTORY.newBuilder() + buildSearchArgument(conjunction, builder).map(_.build()) + } + } + + private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { + def newBuilder = SearchArgument.FACTORY.newBuilder() + + def isSearchableLiteral(value: Any) = value match { + // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. + case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar | + _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true + case _ => false + } + + // lian: I probably missed something here, and had to end up with a pretty weird double-checking + // pattern when converting `And`/`Or`/`Not` filters. + // + // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, + // and `startNot()` mutate internal state of the builder instance. This forces us to translate + // all convertible filters with a single builder instance. However, before actually converting a + // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible + // filter is found, we may already end up with a builder whose internal state is inconsistent. + // + // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and + // then try to convert its children. Say we convert `left` child successfully, but find that + // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is + // inconsistent now. + // + // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + // children with brand new builders, and only do the actual conversion with the right builder + // instance when the children are proven to be convertible. + // + // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. + // Usage of builder methods mentioned above can only be found in test code, where all tested + // filters are known to be convertible. + + expression match { + case And(left, right) => + val tryLeft = buildSearchArgument(left, newBuilder) + val tryRight = buildSearchArgument(right, newBuilder) + + val conjunction = for { + _ <- tryLeft + _ <- tryRight + lhs <- buildSearchArgument(left, builder.startAnd()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + // For filter `left AND right`, we can still push down `left` even if `right` is not + // convertible, and vice versa. + conjunction + .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) + .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) + + case Or(left, right) => + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) + lhs <- buildSearchArgument(left, builder.startOr()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(child, newBuilder) + negate <- buildSearchArgument(child, builder.startNot()) + } yield negate.end() + + case EqualTo(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.equals(attribute, _)) + + case LessThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThan(attribute, _)) + + case LessThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThanEquals(attribute, _)) + + case GreaterThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThanEquals(attribute, _).end()) + + case GreaterThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThan(attribute, _).end()) + + case IsNull(attribute) => + Some(builder.isNull(attribute)) + + case IsNotNull(attribute) => + Some(builder.startNot().isNull(attribute).end()) + + case In(attribute, values) => + Option(values) + .filter(_.forall(isSearchableLiteral)) + .map(builder.in(attribute, _)) + + case _ => None + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala new file mode 100644 index 000000000000..9dc9fbb78e01 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.util.Properties + +import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + +import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.sources.{Filter, _} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.util.SerializableConfiguration + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +private[sql] class DefaultSource extends HadoopFsRelationProvider { + def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + assert( + sqlContext.isInstanceOf[HiveContext], + "The ORC data source can only be used with HiveContext.") + + new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + } +} + +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + + private val serializer = { + val table = new Properties() + table.setProperty("columns", dataSchema.fieldNames.mkString(",")) + table.setProperty("columns.types", dataSchema.map { f => + HiveMetastoreTypes.toMetastoreType(f.dataType) + }.mkString(":")) + + val serde = new OrcSerde + serde.initialize(context.getConfiguration, table) + serde + } + + // Object inspector converted from the schema of the relation to be written. + private val structOI = { + val typeInfo = + TypeInfoUtils.getTypeInfoFromTypeString( + HiveMetastoreTypes.toMetastoreType(dataSchema)) + + TypeInfoUtils + .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) + .asInstanceOf[StructObjectInspector] + } + + // Used to hold temporary `Writable` fields of the next row to be written. + private val reusableOutputBuffer = new Array[Any](dataSchema.length) + + // Used to convert Catalyst values into Hadoop `Writable`s. + private val wrappers = structOI.getAllStructFieldRefs.map { ref => + wrapperFor(ref.getFieldObjectInspector) + }.toArray + + // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this + // flag to decide whether `OrcRecordWriter.close()` needs to be called. + private var recordWriterInstantiated = false + + private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + recordWriterInstantiated = true + + val conf = context.getConfiguration + val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") + val partition = context.getTaskAttemptID.getTaskID.getId + val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" + + new OrcOutputFormat().getRecordWriter( + new Path(path, filename).getFileSystem(conf), + conf.asInstanceOf[JobConf], + new Path(path, filename).toString, + Reporter.NULL + ).asInstanceOf[RecordWriter[NullWritable, Writable]] + } + + override def write(row: Row): Unit = { + var i = 0 + while (i < row.length) { + reusableOutputBuffer(i) = wrappers(i)(row(i)) + i += 1 + } + + recordWriter.write( + NullWritable.get(), + serializer.serialize(reusableOutputBuffer, structOI)) + } + + override def close(): Unit = { + if (recordWriterInstantiated) { + recordWriter.close(Reporter.NULL) + } + } +} + +@DeveloperApi +private[sql] class OrcRelation( + override val paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) + with Logging { + + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + + override val dataSchema: StructType = maybeDataSchema.getOrElse { + OrcFileOperator.readSchema( + paths.head, Some(sqlContext.sparkContext.hadoopConfiguration)) + } + + override def needConversion: Boolean = false + + override def equals(other: Any): Boolean = other match { + case that: OrcRelation => + paths.toSet == that.paths.toSet && + dataSchema == that.dataSchema && + schema == that.schema && + partitionColumns == that.partitionColumns + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode( + paths.toSet, + dataSchema, + schema, + partitionColumns) + } + + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[FileStatus]): RDD[Row] = { + val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes + OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row]) + } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + job.getConfiguration match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + } + } +} + +private[orc] case class OrcTableScan( + attributes: Seq[Attribute], + @transient relation: OrcRelation, + filters: Array[Filter], + @transient inputPaths: Array[FileStatus]) + extends Logging + with HiveInspectors { + + @transient private val sqlContext = relation.sqlContext + + private def addColumnIds( + output: Seq[Attribute], + relation: OrcRelation, + conf: Configuration): Unit = { + val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) + val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIds, sortedNames) + } + + // Transform all given raw `Writable`s into `InternalRow`s. + private def fillObject( + path: String, + conf: Configuration, + iterator: Iterator[Writable], + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[InternalRow] = { + val deserializer = new OrcSerde + val maybeStructOI = OrcFileOperator.getObjectInspector(path, Some(conf)) + + // SPARK-8501: ORC writes an empty schema ("struct<>") to an ORC file if the file contains zero + // rows, and thus couldn't give a proper ObjectInspector. In this case we just return an empty + // partition since we know that this file is empty. + maybeStructOI.map { soi => + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = fieldRefs.map(unwrapperFor) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + mutableRow: InternalRow + } + }.getOrElse { + Iterator.empty + } + } + + def execute(): RDD[InternalRow] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration + + // Tries to push down filters if ORC filter push-down is enabled + if (sqlContext.conf.orcFilterPushDown) { + OrcFilters.createFilter(filters).foreach { f => + conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + + // Sets requested columns + addColumnIds(attributes, relation, conf) + + if (inputPaths.nonEmpty) { + FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) + } + + val inputFormatClass = + classOf[OrcInputFormat] + .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]] + + val rdd = sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], + inputFormatClass, + classOf[NullWritable], + classOf[Writable] + ).asInstanceOf[HadoopRDD[NullWritable, Writable]] + + val wrappedConf = new SerializableConfiguration(conf) + + rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + fillObject( + split.getPath.toString, + wrappedConf.value, + iterator.map(_._2), + attributes.zipWithIndex, + mutableRow) + } + } +} + +private[orc] object OrcTableScan { + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. + private[orc] val SARG_PUSHDOWN = "sarg.pushdown" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 1598d4bd4755..7978fdacaedb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -48,7 +48,14 @@ import scala.collection.JavaConversions._ // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) + new SparkContext( + System.getProperty("spark.sql.test.master", "local[32]"), + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set( + "spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe"))) /** * A locally running test instance of Spark's Hive execution engine. @@ -75,9 +82,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val warehousePath = Utils.createTempDir() + private lazy val temporaryConfig = newTemporaryConfiguration() + /** Sets up the system initially or after a RESET command */ protected override def configure(): Map[String, String] = - newTemporaryConfiguration() ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) + temporaryConfig ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) val testTempDir = Utils.createTempDir() @@ -103,12 +112,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { protected[hive] class SQLSession extends super.SQLSession { /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5) // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = - getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } } @@ -180,7 +188,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } - case class TestTable(name: String, commands: (()=>Unit)*) + case class TestTable(name: String, commands: (() => Unit)*) protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { @@ -244,8 +252,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { | 'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS - |INPUTFORMAT '${classOf[SequenceFileInputFormat[_,_]].getName}' - |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_,_]].getName}' + |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' + |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' """.stripMargin) runSqlHive( @@ -383,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * Records the UDFs present when the server starts, so we can delete ones that are created by * tests. */ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** * Resets the test instance by deleting any tables that have been created. @@ -402,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java new file mode 100644 index 000000000000..67576a72f198 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListInt.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.Arrays; +import java.util.List; + +public class UDFToListInt extends UDF { + public List evaluate(Object o) { + return Arrays.asList(1, 2, 3); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java new file mode 100644 index 000000000000..f02395cbba88 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFToListString.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.Arrays; +import java.util.List; + +public class UDFToListString extends UDF { + public List evaluate(Object o) { + return Arrays.asList("data1", "data2", "data3"); + } +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java new file mode 100644 index 000000000000..c4828c471764 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.hive; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.hive.HiveContext; +import org.apache.spark.sql.hive.test.TestHive$; + +public class JavaDataFrameSuite { + private transient JavaSparkContext sc; + private transient HiveContext hc; + + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + hc = TestHive$.MODULE$; + sc = new JavaSparkContext(hc.sparkContext()); + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); + } + df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df.registerTempTable("window_table"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + hc.sql("DROP TABLE IF EXISTS window_table"); + } + + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select(functions.avg("key").over( + Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java similarity index 89% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 53ddecf57958..64d1ce92931e 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -14,7 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.hive; + +package test.org.apache.spark.sql.hive; import java.io.File; import java.io.IOException; @@ -36,6 +37,7 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -81,7 +83,7 @@ public void setUp() throws IOException { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.jsonRDD(rdd); + df = sqlContext.read().json(rdd); df.registerTempTable("jsonTable"); } @@ -96,7 +98,11 @@ public void tearDown() throws IOException { public void saveExternalTableAndQueryIt() { Map options = new HashMap(); options.put("path", path.toString()); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), @@ -115,7 +121,11 @@ public void saveExternalTableAndQueryIt() { public void saveExternalTableWithSchemaAndQueryIt() { Map options = new HashMap(); options.put("path", path.toString()); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), @@ -138,7 +148,11 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { Map options = new HashMap(); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000 similarity index 100% rename from sql/hive/src/test/resources/data/files/testUdf/part-00000 rename to sql/hive/src/test/resources/data/files/testUDF/part-00000 diff --git a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d deleted file mode 100644 index 98da82fa8938..000000000000 --- a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d +++ /dev/null @@ -1 +0,0 @@ -1970-01-01 1970-01-01 1969-12-31 16:00:00 1969-12-31 16:00:00 1970-01-01 00:00:00 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 deleted file mode 100644 index 27ba77ddaf61..000000000000 --- a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 +++ /dev/null @@ -1 +0,0 @@ -true diff --git a/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 b/sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad similarity index 100% rename from sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-cc120a2331158f570a073599985d3f55 rename to sql/hive/src/test/resources/golden/constant object inspector for generic udf-0-638f81ad9077c7d0c5c735c6e73742ad diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 index 27de46fdf22a..84a31a5a6970 100644 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 @@ -1 +1 @@ --0.0010000000000000009 +-0.001 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 new file mode 100644 index 000000000000..7e5fceeddeee --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-6dfcd7925fb267699c4bf82737d4609 @@ -0,0 +1,97 @@ +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 2 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 6 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 34 66619.10876874991 0.811328754177887 2801.7074999999995 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 2 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 34 74912.8826888888 1.0 4128.782222222221 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 2 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 6 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 28 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 34 53315.51002399992 0.695639377397664 2210.7864 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 2 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 6 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 28 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 34 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 42 41099.896184 0.630785977101214 2009.9536000000007 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 6 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 28 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 34 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 42 14788.129118750014 0.2036684720435979 331.1337500000004 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 6 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 28 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 42 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 2 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 14 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 40 20231.169866666663 -0.49369526554523185 -1113.7466666666658 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 2 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 14 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 25 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 40 18978.662075 -0.5205630897335946 -1004.4812499999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 2 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 14 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 18 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 25 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 40 16910.329504000005 -0.46908967495720255 -766.1791999999995 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 2 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 18 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 25 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 40 18374.07627499999 -0.6091405874714462 -1128.1787499999987 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 2 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 18 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 25 24473.534488888927 -0.9571686373491608 -1441.4466666666676 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 14 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 17 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 19 38720.09628888887 0.5557168646224995 224.6944444444446 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 1 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 14 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 17 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 19 75702.81305 -0.6720833036576083 -1296.9000000000003 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 1 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 14 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 17 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 19 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 45 67722.117896 -0.5703526513979519 -2129.0664 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 1 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 14 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 19 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 45 76128.53331875012 -0.577476899644802 -2547.7868749999993 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 1 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 19 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 45 67902.76602222225 -0.8710736366736884 -4099.731111111111 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 10 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 27 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 39 28944.25735555559 -0.6656975320098423 -1347.4777777777779 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 7 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 10 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 27 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 39 58693.95151875002 -0.8051852719193339 -2537.328125 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 7 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 10 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 12 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 27 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 39 54802.817784000035 -0.6046935574240581 -1719.8079999999995 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 7 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 12 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 27 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 39 61174.24181875003 -0.5508665654707869 -1719.0368749999975 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 7 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 12 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 27 80278.40095555557 -0.7755740084632333 -1867.4888888888881 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 2 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 6 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 31 7005.487488888913 0.39004303087285047 418.9233333333353 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 2 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 6 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 31 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 46 100286.53662500004 -0.713612911776183 -4090.853749999999 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 2 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 6 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 23 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 31 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 46 81456.04997600002 -0.712858514567818 -3297.2011999999986 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 2 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 6 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 23 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 46 81474.56091875004 -0.984128787153391 -4871.028125000002 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 2 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 23 99807.08486666664 -0.9978877469246936 -5664.856666666666 +Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 46 99807.08486666664 -0.9978877469246936 -5664.856666666666 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 b/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 deleted file mode 100644 index 1f7e8a5d6703..000000000000 --- a/sql/hive/src/test/resources/golden/windowing.q -- 20. testSTATs-0-da0e0cca69e42118a96b8609b8fa5838 +++ /dev/null @@ -1,26 +0,0 @@ -Manufacturer#1 almond antique burnished rose metallic 2 273.70217881648074 273.70217881648074 [34,2] 74912.8826888888 1.0 4128.782222222221 -Manufacturer#1 almond antique burnished rose metallic 2 258.10677784349235 258.10677784349235 [34,2,6] 66619.10876874991 0.811328754177887 2801.7074999999995 -Manufacturer#1 almond antique chartreuse lavender yellow 34 230.90151585470358 230.90151585470358 [34,2,6,28] 53315.51002399992 0.695639377397664 2210.7864 -Manufacturer#1 almond antique salmon chartreuse burlywood 6 202.73109328368946 202.73109328368946 [34,2,6,42,28] 41099.896184 0.630785977101214 2009.9536000000007 -Manufacturer#1 almond aquamarine burnished black steel 28 121.6064517973862 121.6064517973862 [34,6,42,28] 14788.129118750014 0.2036684720435979 331.1337500000004 -Manufacturer#1 almond aquamarine pink moccasin thistle 42 96.5751586416853 96.5751586416853 [6,42,28] 9326.761266666683 -1.4442181184933883E-4 -0.20666666666708502 -Manufacturer#2 almond antique violet chocolate turquoise 14 142.2363169751898 142.2363169751898 [2,40,14] 20231.169866666663 -0.49369526554523185 -1113.7466666666658 -Manufacturer#2 almond antique violet turquoise frosted 40 137.76306498840682 137.76306498840682 [2,25,40,14] 18978.662075 -0.5205630897335946 -1004.4812499999995 -Manufacturer#2 almond aquamarine midnight light salmon 2 130.03972279269132 130.03972279269132 [2,18,25,40,14] 16910.329504000005 -0.46908967495720255 -766.1791999999995 -Manufacturer#2 almond aquamarine rose maroon antique 25 135.55100986344584 135.55100986344584 [2,18,25,40] 18374.07627499999 -0.6091405874714462 -1128.1787499999987 -Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 156.44019460768044 156.44019460768044 [2,18,25] 24473.534488888927 -0.9571686373491608 -1441.4466666666676 -Manufacturer#3 almond antique chartreuse khaki white 17 196.7742266885805 196.7742266885805 [17,19,14] 38720.09628888887 0.5557168646224995 224.6944444444446 -Manufacturer#3 almond antique forest lavender goldenrod 14 275.14144189852607 275.14144189852607 [17,1,19,14] 75702.81305 -0.6720833036576083 -1296.9000000000003 -Manufacturer#3 almond antique metallic orange dim 19 260.23473614412046 260.23473614412046 [17,1,19,14,45] 67722.117896 -0.5703526513979519 -2129.0664 -Manufacturer#3 almond antique misty red olive 1 275.9139962356932 275.9139962356932 [1,19,14,45] 76128.53331875012 -0.577476899644802 -2547.7868749999993 -Manufacturer#3 almond antique olive coral navajo 45 260.5815918713796 260.5815918713796 [1,19,45] 67902.76602222225 -0.8710736366736884 -4099.731111111111 -Manufacturer#4 almond antique gainsboro frosted violet 10 170.13011889596618 170.13011889596618 [39,27,10] 28944.25735555559 -0.6656975320098423 -1347.4777777777779 -Manufacturer#4 almond antique violet mint lemon 39 242.26834609323197 242.26834609323197 [39,7,27,10] 58693.95151875002 -0.8051852719193339 -2537.328125 -Manufacturer#4 almond aquamarine floral ivory bisque 27 234.10001662537326 234.10001662537326 [39,7,27,10,12] 54802.817784000035 -0.6046935574240581 -1719.8079999999995 -Manufacturer#4 almond aquamarine yellow dodger mint 7 247.3342714197732 247.3342714197732 [39,7,27,12] 61174.24181875003 -0.5508665654707869 -1719.0368749999975 -Manufacturer#4 almond azure aquamarine papaya violet 12 283.3344330566893 283.3344330566893 [7,27,12] 80278.40095555557 -0.7755740084632333 -1867.4888888888881 -Manufacturer#5 almond antique blue firebrick mint 31 83.69879024746363 83.69879024746363 [2,6,31] 7005.487488888913 0.39004303087285047 418.9233333333353 -Manufacturer#5 almond antique medium spring khaki 6 316.68049612345885 316.68049612345885 [2,6,46,31] 100286.53662500004 -0.713612911776183 -4090.853749999999 -Manufacturer#5 almond antique sky peru orange 2 285.40506298242155 285.40506298242155 [2,23,6,46,31] 81456.04997600002 -0.712858514567818 -3297.2011999999986 -Manufacturer#5 almond aquamarine dodger light gainsboro 46 285.43749038756283 285.43749038756283 [2,23,6,46] 81474.56091875004 -0.984128787153391 -4871.028125000002 -Manufacturer#5 almond azure blanched chiffon midnight 23 315.9225931564038 315.9225931564038 [2,23,46] 99807.08486666664 -0.9978877469246936 -5664.856666666666 diff --git a/sql/hive/src/test/resources/hive-contrib-0.13.1.jar b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar new file mode 100644 index 000000000000..ce0740d9245a Binary files /dev/null and b/sql/hive/src/test/resources/hive-contrib-0.13.1.jar differ diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala new file mode 100644 index 000000000000..0e428ba1d745 --- /dev/null +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.hive.HiveContext + +/** + * Entry point in test application for SPARK-8489. + * + * This file is not meant to be compiled during tests. It is already included + * in a pre-built "test.jar" located in the same directory as this file. + * This is included here for reference only and should NOT be modified without + * rebuilding the test jar itself. + * + * This is used in org.apache.spark.sql.hive.HiveSparkSubmitSuite. + */ +object Main { + def main(args: Array[String]) { + println("Running regression test for SPARK-8489.") + val sc = new SparkContext("local", "testing") + val hc = new HiveContext(sc) + // This line should not throw scala.reflect.internal.MissingRequirementError. + // See SPARK-8470 for more detail. + val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) + df.collect() + println("Regression test for SPARK-8489 success!") + sc.stop() + } +} + diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala similarity index 76% rename from core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala rename to sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala index d75959f48075..b1681745c2ef 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala @@ -15,10 +15,6 @@ * limitations under the License. */ -package org.apache.spark.util.collection +/** Dummy class used in regression test SPARK-8489. */ +case class MyCoolClass(past: String, present: String, future: String) -private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] { - def hasNext: Boolean = iter.hasNext - - def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V]) -} diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar new file mode 100644 index 000000000000..5944aa6076a5 Binary files /dev/null and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index fc6c3c35037b..39d315aaeab5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -57,7 +57,7 @@ class CachedTableSuite extends QueryTest { checkAnswer( sql("SELECT * FROM src s"), preCacheResults) - + uncacheTable("src") assertCached(sql("SELECT * FROM src"), 0) } @@ -162,7 +162,7 @@ class CachedTableSuite extends QueryTest { test("REFRESH TABLE also needs to recache the data (data source tables)") { val tempPath: File = Utils.createTempDir() tempPath.delete() - table("src").save(tempPath.toString, "parquet", SaveMode.Overwrite) + table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) sql("DROP TABLE IF EXISTS refreshTable") createExternalTable("refreshTable", tempPath.toString, "parquet") checkAnswer( @@ -172,7 +172,7 @@ class CachedTableSuite extends QueryTest { sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) // Append new data. - table("src").save(tempPath.toString, "parquet", SaveMode.Append) + table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) // We are still using the old data. assertCached(table("refreshTable")) checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala new file mode 100644 index 000000000000..fb10f8583da9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.scalatest.BeforeAndAfterAll + +// TODO ideally we should put the test suite into the package `sql`, as +// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't +// support the `cube` or `rollup` yet. +class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { + private var testData: DataFrame = _ + + override def beforeAll() { + testData = Seq((1, 2), (2, 4)).toDF("a", "b") + TestHive.registerDataFrameAsTable(testData, "mytable") + } + + override def afterAll(): Unit = { + TestHive.dropTempTable("mytable") + } + + test("rollup") { + checkAnswer( + testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() + ) + + checkAnswer( + testData.rollup("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() + ) + } + + test("cube") { + checkAnswer( + testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() + ) + + checkAnswer( + testData.cube("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with cube").collect() + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala new file mode 100644 index 000000000000..efb3f2545db8 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +class HiveDataFrameWindowSuite extends QueryTest { + + test("reuse window partitionBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.orderBy("value").partitionBy("key") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("lead") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lead(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lead with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), + sql( + """SELECT + | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("rank functions in unspecific window") { + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + max("key").over(Window.partitionBy("value").orderBy("key")), + min("key").over(Window.partitionBy("value").orderBy("key")), + mean("key").over(Window.partitionBy("value").orderBy("key")), + count("key").over(Window.partitionBy("value").orderBy("key")), + sum("key").over(Window.partitionBy("value").orderBy("key")), + ntile(2).over(Window.partitionBy("value").orderBy("key")), + rowNumber().over(Window.partitionBy("value").orderBy("key")), + denseRank().over(Window.partitionBy("value").orderBy("key")), + rank().over(Window.partitionBy("value").orderBy("key")), + cumeDist().over(Window.partitionBy("value").orderBy("key")), + percentRank().over(Window.partitionBy("value").orderBy("key"))), + sql( + s"""SELECT + |key, + |max(key) over (partition by value order by key), + |min(key) over (partition by value order by key), + |avg(key) over (partition by value order by key), + |count(key) over (partition by value order by key), + |sum(key) over (partition by value order by key), + |ntile(2) over (partition by value order by key), + |row_number() over (partition by value order by key), + |dense_rank() over (partition by value order by key), + |rank() over (partition by value order by key), + |cume_dist() over (partition by value order by key), + |percent_rank() over (partition by value order by key) + |FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows between") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows betweens with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) + .equalTo("2") + .as("last_v"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) + .as("avg_key1"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) + .as("avg_key2"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) + .as("avg_key3") + ), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) + | FROM window_table""".stripMargin).collect()) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2a7374cc172b..a93acb938d5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -26,12 +26,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.LongWritable -import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions.{Literal, Row} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Literal, InternalRow} import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row -class HiveInspectorSuite extends FunSuite with HiveInspectors { +class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { val udaf = new UDAFPercentile.PercentileLongEvaluator() udaf.init() @@ -45,7 +46,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { classOf[UDAFPercentile.State], ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] - val a = unwrap(state, soi).asInstanceOf[Row] + val a = unwrap(state, soi).asInstanceOf[InternalRow] val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] val sfCounts = soi.getStructFieldRef("counts") @@ -78,10 +79,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(java.sql.Date.valueOf("2014-09-23")) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: - Literal(Array[Byte](1,2,3)) :: - Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal.create(Row(1,2.0d,3.0f), + Literal(Array[Byte](1, 2, 3)) :: + Literal.create(Seq[Int](1, 2, 3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1 -> 2, 2 -> 1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1, 2.0d, 3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -111,8 +112,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { @@ -127,7 +128,7 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: Row): Unit = { + def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { row1.zip(row2.toSeq).foreach { case (r1, r2) => checkValue(r1, r2) } @@ -201,9 +202,9 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val dt = StructType(dataTypes.zipWithIndex.map { case (t, idx) => StructField(s"c_$idx", t) }) - + val inspector = toInspector(dt) checkValues(row, - unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) + unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index fa8e11ffec2b..e9bb32667936 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.test.TestHive -import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends FunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite { test("struct field should accept underscore in sub-column name") { val metastr = "struct" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 7ff5719adb3a..af68615e8e9d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, SQLConf} +import org.apache.spark.sql.{QueryTest, Row, SQLConf} case class Cases(lower: String, UPPER: String) @@ -55,8 +54,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -68,8 +67,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") @@ -82,11 +81,11 @@ class HiveParquetSuite extends QueryTest with ParquetTest { } } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { run("Parquet data source enabled") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "false") { run("Parquet data source disabled") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 941a2941649b..f765395e148a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class HiveQlSuite extends FunSuite with BeforeAndAfterAll { +class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll() { if (SessionState.get() == null) { SessionState.start(new HiveConf()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala new file mode 100644 index 000000000000..a38ed23b5cf9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.sys.process.{ProcessLogger, Process} + +import org.apache.spark._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +/** + * This suite tests spark-submit with applications using HiveContext. + */ +class HiveSparkSubmitSuite + extends SparkFunSuite + with Matchers + with ResetSystemProperties + with Timeouts { + + // TODO: rewrite these or mark them as slow tests to be run sparingly + + def beforeAll() { + System.setProperty("spark.testing", "true") + } + + test("SPARK-8368: includes jars passed in through --jars") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") + val args = Seq( + "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSubmitClassLoaderTest", + "--master", "local-cluster[2,1,512]", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("SPARK-8020: set sql conf in spark conf") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,512]", + unusedJar.toString) + runSparkSubmit(args) + } + + test("SPARK-8489: MissingRequirementError during reflection") { + // This test uses a pre-built jar to test SPARK-8489. In a nutshell, this test creates + // a HiveContext and uses it to create a data frame from an RDD using reflection. + // Before the fix in SPARK-8470, this results in a MissingRequirementError because + // the HiveContext code mistakenly overrides the class loader that contains user classes. + // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. + val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" + val args = Seq("--class", "Main", testJar) + runSparkSubmit(args) + } + + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + // This is copied from org.apache.spark.deploy.SparkSubmitSuite + private def runSparkSubmit(args: Seq[String]): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val process = Process( + Seq("./bin/spark-submit") ++ args, + new File(sparkHome), + "SPARK_TESTING" -> "1", + "SPARK_HOME" -> sparkHome + ).run(ProcessLogger( + (line: String) => { println(s"out> $line") }, + (line: String) => { println(s"err> $line") } + )) + + try { + val exitCode = failAfter(180 seconds) { process.exitValue() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + +// This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. +// We test if we can load user jars in both driver and executors when HiveContext is used. +object SparkSubmitClassLoaderTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + logInfo("Testing load classes at the driver side.") + // First, we load classes at driver side. + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + throw new Exception("Could not load user class from jar:\n", t) + } + // Second, we load classes at the executor side. + logInfo("Testing load classes at the executor side.") + val result = df.mapPartitions { x => + var exception: String = null + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + exception = t + "\n" + t.getStackTraceString + exception = exception.replaceAll("\n", "\n\t") + } + Option(exception).toSeq.iterator + }.collect() + if (result.nonEmpty) { + throw new Exception("Could not load user class from jar:\n" + result(0)) + } + + // Load a Hive UDF from the jar. + logInfo("Registering temporary Hive UDF provided in a jar.") + hiveContext.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Load a Hive SerDe from the jar. + logInfo("Creating a Hive table with a SerDe provided in a jar.") + hiveContext.sql( + """ + |CREATE TABLE t1(key int, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + // Actually use the loaded UDF and SerDe. + logInfo("Writing data into the table.") + hiveContext.sql( + "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = hiveContext.table("t1").orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"table t1 should have 10 rows instead of $count rows") + } + logInfo("Test finishes.") + sc.stop() + } +} + +// This object is used for testing SPARK-8020: https://issues.apache.org/jira/browse/SPARK-8020. +// We test if we can correctly set spark sql configurations when HiveContext is used. +object SparkSQLConfTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + // We override the SparkConf to add spark.sql.hive.metastore.version and + // spark.sql.hive.metastore.jars to the beginning of the conf entry array. + // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but + // before spark.sql.hive.metastore.jars get set, we will see the following exception: + // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only + // be used when hive execution version == hive metastore version. + // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. + val conf = new SparkConf() { + override def getAll: Array[(String, String)] = { + def isMetastoreSetting(conf: String): Boolean = { + conf == "spark.sql.hive.metastore.version" || conf == "spark.sql.hive.metastore.jars" + } + // If there is any metastore settings, remove them. + val filteredSettings = super.getAll.filterNot(e => isMetastoreSetting(e._1)) + + // Always add these two metastore settings at the beginning. + ("spark.sql.hive.metastore.version" -> "0.12") +: + ("spark.sql.hive.metastore.jars" -> "maven") +: + filteredSettings + } + + // For this simple test, we do not really clone this object. + override def clone: SparkConf = this + } + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Run a simple command to make sure all lazy vals in hiveContext get instantiated. + hiveContext.tables().collect() + sc.stop() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index ecb990e8aac9..aa5dbe2db690 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -53,7 +53,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has also been updated. checkAnswer( @@ -62,7 +62,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { ) // Add more data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has been updated. checkAnswer( @@ -71,7 +71,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { ) // Now overwrite. - testData.insertInto("createAndInsertTest", overwrite = true) + testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") // Make sure the registered table has also been updated. checkAnswer( @@ -160,7 +160,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } @@ -240,7 +240,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { checkAnswer(sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) - + // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index e12a6c21ccac..1c15997ea8e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -29,7 +29,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { import org.apache.spark.sql.hive.test.TestHive.implicits._ val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 1bf1c1be3e3d..cc294bc3e8bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -21,741 +21,833 @@ import java.io.File import scala.collection.mutable.ArrayBuffer +import org.scalatest.BeforeAndAfterAll + import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException -import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { - - override def afterEach(): Unit = { - reset() - Utils.deleteRecursively(tempPath) - } +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override val sqlContext = TestHive - val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var tempPath: File = Utils.createTempDir() - tempPath.delete() + var jsonFilePath: String = _ - test ("persistent JSON table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + override def beforeAll(): Unit = { + jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } - test ("persistent JSON table with a user specified schema") { - sql( - s""" - |CREATE TABLE jsonTable ( - |a string, - |b String, - |`c_!@(3)` int, - |`` Struct<`d!`:array, `=`:array>>) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - jsonFile(filePath).registerTempTable("expectedJsonTable") + test("persistent JSON table") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - checkAnswer( - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } - test ("persistent JSON table with a user specified schema with a subset of fields") { - // This works because JSON objects are self-describing and JSONRelation can get needed - // field values based on field names. - sql( - s""" - |CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val innerStruct = StructType( - StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))) :: Nil) - val expectedSchema = StructType( - StructField("", innerStruct, true) :: - StructField("b", StringType, true) :: Nil) - - assert(expectedSchema === table("jsonTable").schema) - - jsonFile(filePath).registerTempTable("expectedJsonTable") + test("persistent JSON table with a user specified schema") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) + } + } + } - checkAnswer( - sql("SELECT b, ``.`=` FROM jsonTable"), - sql("SELECT b, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema with a subset of fields") { + withTable("jsonTable") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. + sql( + s"""CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val innerStruct = StructType(Seq( + StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))))) + + val expectedSchema = StructType(Seq( + StructField("", innerStruct, true), + StructField("b", StringType, true))) + + assert(expectedSchema === table("jsonTable").schema) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT b, ``.`=` FROM jsonTable"), + sql("SELECT b, ``.`=` FROM expectedJsonTable")) + } + } } test("resolve shortened provider names") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } test("drop table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath)) - sql("DROP TABLE jsonTable") + sql("DROP TABLE jsonTable") - intercept[Exception] { - sql("SELECT * FROM jsonTable").collect() - } + intercept[Exception] { + sql("SELECT * FROM jsonTable").collect() + } - assert( - (new File(filePath)).exists(), - "The table with specified path is considered as an external table, " + - "its data should not deleted after DROP TABLE.") + assert( + new File(jsonFilePath).exists(), + "The table with specified path is considered as an external table, " + + "its data should not deleted after DROP TABLE.") + } } test("check change without refresh") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - // Schema is cached so the new column does not show. The updated values in existing columns - // will show. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1")) - - sql("REFRESH TABLE jsonTable") - - // Check that the refresh worked - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1", "c1")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + withTable("jsonTable") { + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a1", "b1", "c1") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + // Schema is cached so the new column does not show. The updated values in existing columns + // will show. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1")) + + sql("REFRESH TABLE jsonTable") + + // Check that the refresh worked + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1", "c1")) + } + } } test("drop, change, recreate") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql("DROP TABLE jsonTable") - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - // New table should reflect new schema. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b", "c")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql("DROP TABLE jsonTable") + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + // New table should reflect new schema. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b", "c")) + } + } } test("invalidate cache and reload") { - sql( - s""" - |CREATE TABLE jsonTable (`c_!@(3)` int) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable (`c_!@(3)` int) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - jsonFile(filePath).registerTempTable("expectedJsonTable") + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - // Discard the cached relation. - invalidateTable("jsonTable") + // Discard the cached relation. + invalidateTable("jsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - invalidateTable("jsonTable") - val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) + invalidateTable("jsonTable") + val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema === table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) + } + } } test("CTAS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { tempPath => + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS with IF NOT EXISTS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - // Create the table again should trigger a AnalysisException. - val message = intercept[AnalysisException] { - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - }.getMessage - assert(message.contains("Table ctasJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // The following statement should be fine if it has IF NOT EXISTS. - // It tries to create a table ctasJsonTable with a new schema. - // The actual table's schema and data should not be changed. - sql( - s""" - |CREATE TABLE IF NOT EXISTS ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT a FROM jsonTable - """.stripMargin) - - // Discard the cached relation. - invalidateTable("ctasJsonTable") - - // Schema should not be changed. - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - // Table data should not be changed. - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + // Create the table again should trigger a AnalysisException. + val message = intercept[AnalysisException] { + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + }.getMessage + + assert( + message.contains("Table ctasJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // The following statement should be fine if it has IF NOT EXISTS. + // It tries to create a table ctasJsonTable with a new schema. + // The actual table's schema and data should not be changed. + sql( + s"""CREATE TABLE IF NOT EXISTS ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT a FROM jsonTable + """.stripMargin) + + // Discard the cached relation. + invalidateTable("ctasJsonTable") + + // Schema should not be changed. + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + // Table data should not be changed. + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS a managed table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") - val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) - - // It is a managed table when we do not specify the location. - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - - sql( - s""" - |CREATE TABLE loadedTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${expectedPath}' - |) - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("loadedTable").schema) + withTable("jsonTable", "ctasJsonTable", "loadedTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val filesystemPath = new Path(expectedPath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) + + // It is a managed table when we do not specify the location. + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |AS + |SELECT * FROM jsonTable + """.stripMargin) - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM loadedTable").collect() - ) + assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") + + sql( + s"""CREATE TABLE loadedTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$expectedPath' + |) + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("loadedTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable")) - sql("DROP TABLE ctasJsonTable") - assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + sql("DROP TABLE ctasJsonTable") + assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + } } test("SPARK-5286 Fail to drop an invalid table when using the data source API") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path 'it is not a path at all!' - |) - """.stripMargin) - - sql("DROP TABLE jsonTable").collect().foreach(println) + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path 'it is not a path at all!' + |) + """.stripMargin) + + sql("DROP TABLE jsonTable").collect().foreach(println) + } } test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). - df.saveAsTable("savedJsonTable") - - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) - - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + withTable("savedJsonTable") { + // Save the df as a managed table (by not specifying the path). + (1 to 10) + .map(i => i -> s"str$i") + .toDF("a", "b") + .write + .format("json") + .saveAsTable("savedJsonTable") - invalidateTable("savedJsonTable") + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + invalidateTable("savedJsonTable") - // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) + } } test("save table") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). - df.saveAsTable("savedJsonTable") - - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Right now, we cannot append to an existing JSON table. - intercept[RuntimeException] { - df.saveAsTable("savedJsonTable", SaveMode.Append) - } - - // We can overwrite it. - df.saveAsTable("savedJsonTable", SaveMode.Overwrite) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // When the save mode is Ignore, we will do nothing when the table already exists. - df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore) - assert(df.schema === table("savedJsonTable").schema) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") - intercept[InvalidInputException] { - jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable")) + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("savedJsonTable") { + val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b") + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + // Save the df as a managed table (by not specifying the path). + df.write.saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") + } + + // We can overwrite it. + df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") + assert(df.schema === table("savedJsonTable").schema) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + } + + // Create an external table by specifying the path. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { + df.write + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + } + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer(read.json(tempPath.toString), df) + } } - - // Create an external table by specifying the path. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.saveAsTable( - "savedJsonTable", - "org.apache.spark.sql.json", - SaveMode.Append, - Map("path" -> tempPath.toString)) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Data should not be deleted after we drop the table. - sql("DROP TABLE savedJsonTable") - checkAnswer( - jsonFile(tempPath.toString), - df.collect()) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } test("create external table") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) + withTempPath { tempPath => + withTable("savedJsonTable", "createdJsonTable") { + val df = read.json(sparkContext.parallelize((1 to 10).map { i => + s"""{ "a": $i, "b": "str$i" }""" + })) + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { + df.write + .format("json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + } + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + createExternalTable("createdJsonTable", tempPath.toString) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer(sql("SELECT * FROM createdJsonTable"), df) + + assert( + intercept[AnalysisException] { + createExternalTable("createdJsonTable", jsonFilePath.toString) + }.getMessage.contains("Table createdJsonTable already exists."), + "We should complain that createdJsonTable already exists") + } + + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") + checkAnswer(read.json(tempPath.toString), df) + + // Try to specify the schema. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable")) + + sql("DROP TABLE createdJsonTable") + + assert( + intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage.contains("'path' must be specified for json data."), + "We should complain that path is not specified.") + } + } + } + } - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.saveAsTable( - "savedJsonTable", - "org.apache.spark.sql.json", - SaveMode.Append, - Map("path" -> tempPath.toString)) + test("scan a parquet table created through a CTAS statement") { + withSQLConf( + HiveContext.CONVERT_METASTORE_PARQUET.key -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { + + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + } + } + } + } + } - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - createExternalTable("createdJsonTable", tempPath.toString) - assert(table("createdJsonTable").schema === df.schema) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - df.collect()) + test("Pre insert nullability check (ArrayType)") { + withTable("arrayInParquet") { + { + val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("arrayInParquet") + } - var message = intercept[AnalysisException] { - createExternalTable("createdJsonTable", filePath.toString) - }.getMessage - assert(message.contains("Table createdJsonTable already exists."), - "We should complain that ctasJsonTable already exists") + { + val df = (Tuple1(Seq(2, 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("arrayInParquet") + } - // Data should not be deleted. - sql("DROP TABLE createdJsonTable") - checkAnswer( - jsonFile(tempPath.toString), - df.collect()) - - // Try to specify the schema. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - val schema = StructType(StructField("b", StringType, true) :: Nil) - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map("path" -> tempPath.toString)) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - sql("SELECT b FROM savedJsonTable").collect()) - - sql("DROP TABLE createdJsonTable") - - message = intercept[RuntimeException] { - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map.empty[String, String]) - }.getMessage - assert( - message.contains("'path' must be specified for json data."), - "We should complain that path is not specified.") - - sql("DROP TABLE savedJsonTable") - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) - } + (Tuple1(Seq(4, 5)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. - if (HiveShim.version == "0.13.1") { - test("scan a parquet table created through a CTAS statement") { - val originalConvertMetastore = getConf("spark.sql.hive.convertMetastoreParquet", "true") - val originalUseDataSource = getConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - setConf("spark.sql.hive.convertMetastoreParquet", "true") - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") - sql( - """ - |create table test_parquet_ctas STORED AS parquET - |AS select tmp.a from jt tmp where tmp.a < 5 - """.stripMargin) + refreshTable("arrayInParquet") checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), - Row(3) :: Row(4) :: Nil - ) - - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") - } - - // Clenup and reset confs. - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS test_parquet_ctas") - setConf("spark.sql.hive.convertMetastoreParquet", originalConvertMetastore) - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource) + sql("SELECT a FROM arrayInParquet"), + Row(ArrayBuffer(1, null)) :: + Row(ArrayBuffer(2, 3)) :: + Row(ArrayBuffer(4, 5)) :: + Row(ArrayBuffer(6, null)) :: Nil) } } - test("Pre insert nullability check (ArrayType)") { - val df1 = - createDataFrame(Tuple1(Seq(Int.box(1), null.asInstanceOf[Integer])) :: Nil).toDF("a") - val expectedSchema1 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.saveAsTable("arrayInParquet", "parquet", SaveMode.Overwrite) - - val df2 = - createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a") - val expectedSchema2 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.insertInto("arrayInParquet", overwrite = false) - createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a") - .saveAsTable("arrayInParquet", SaveMode.Append) // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a") - .saveAsTable("arrayInParquet", "parquet", SaveMode.Append) - refreshTable("arrayInParquet") + test("Pre insert nullability check (MapType)") { + withTable("mapInParquet") { + { + val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("mapInParquet") + } - checkAnswer( - sql("SELECT a FROM arrayInParquet"), - Row(ArrayBuffer(1, null)) :: - Row(ArrayBuffer(2, 3)) :: - Row(ArrayBuffer(4, 5)) :: - Row(ArrayBuffer(6, null)) :: Nil) + { + val df = (Tuple1(Map(2 -> 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("mapInParquet") + } - sql("DROP TABLE arrayInParquet") - } + (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. - test("Pre insert nullability check (MapType)") { - val df1 = - createDataFrame(Tuple1(Map(1 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") - val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val expectedSchema1 = - StructType( - StructField("a", mapType1, nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.saveAsTable("mapInParquet", "parquet", SaveMode.Overwrite) - - val df2 = - createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a") - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = false) - val expectedSchema2 = - StructType( - StructField("a", mapType2, nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.insertInto("mapInParquet", overwrite = false) - createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a") - .saveAsTable("mapInParquet", SaveMode.Append) // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") - .saveAsTable("mapInParquet", "parquet", SaveMode.Append) - refreshTable("mapInParquet") + (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") - checkAnswer( - sql("SELECT a FROM mapInParquet"), - Row(Map(1 -> null)) :: - Row(Map(2 -> 3)) :: - Row(Map(4 -> 5)) :: - Row(Map(6 -> null)) :: Nil) + refreshTable("mapInParquet") - sql("DROP TABLE mapInParquet") + checkAnswer( + sql("SELECT a FROM mapInParquet"), + Row(Map(1 -> null)) :: + Row(Map(2 -> 3)) :: + Row(Map(4 -> 5)) :: + Row(Map(6 -> null)) :: Nil) + } } test("SPARK-6024 wide schema support") { - // We will need 80 splits for this schema if the threshold is 4000. - val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true))) - assert( - schema.json.size > conf.schemaStringLengthThreshold, - "To correctly test the fix of SPARK-6024, the value of " + - s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}") - // Manually create a metastore data source table. - catalog.createDataSourceTable( - tableName = "wide_schema", - userSpecifiedSchema = Some(schema), - provider = "json", - options = Map("path" -> "just a dummy path"), - isExternal = false) - - invalidateTable("wide_schema") - - val actualSchema = table("wide_schema").schema - assert(schema === actualSchema) + withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { + withTable("wide_schema") { + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) + + // Manually create a metastore data source table. + catalog.createDataSourceTable( + tableName = "wide_schema", + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + provider = "json", + options = Map("path" -> "just a dummy path"), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } + } } test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { val tableName = "spark6655" - val schema = StructType(StructField("int", IntegerType, true) :: Nil) - - val hiveTable = HiveTable( - specifiedDatabase = Some("default"), - name = tableName, - schema = Seq.empty, - partitionColumns = Seq.empty, - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema" -> schema.json, - "EXTERNAL" -> "FALSE"), - tableType = ManagedTable, - serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(tableName))) - - catalog.client.createTable(hiveTable) - - invalidateTable(tableName) - val actualSchema = table(tableName).schema - assert(schema === actualSchema) - sql(s"drop table $tableName") + withTable(tableName) { + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + val hiveTable = HiveTable( + specifiedDatabase = Some("default"), + name = tableName, + schema = Seq.empty, + partitionColumns = Seq.empty, + properties = Map( + "spark.sql.sources.provider" -> "json", + "spark.sql.sources.schema" -> schema.json, + "EXTERNAL" -> "FALSE"), + tableType = ManagedTable, + serdeProperties = Map( + "path" -> catalog.hiveDefaultTableFilePath(tableName))) + + catalog.client.createTable(hiveTable) + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + } } + test("Saving partition columns information") { + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") + val tableName = s"partitionInfo_${System.currentTimeMillis()}" + + withTable(tableName) { + df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) + val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val actualPartitionColumns = + StructType( + metastoreTable.partitionColumns.map(c => + StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) + // Make sure partition columns are correctly stored in metastore. + assert( + expectedPartitionColumns.sameType(actualPartitionColumns), + s"Partitions columns stored in metastore $actualPartitionColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } + } test("insert into a table") { - def createDF(from: Int, to: Int): DataFrame = - createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") + def createDF(from: Int, to: Int): DataFrame = { + (from to to).map(i => i -> s"str$i").toDF("c1", "c2") + } - createDF(0, 9).saveAsTable("insertParquet", "parquet") - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 9).map(i => Row(i, s"str$i"))) + withTable("insertParquet") { + createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 9).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(10, 19).saveAsTable("insertParquet", "parquet") - } + intercept[AnalysisException] { + createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + } - createDF(10, 19).saveAsTable("insertParquet", "parquet", SaveMode.Append) - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 19).map(i => Row(i, s"str$i"))) + createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).saveAsTable("insertParquet", "parquet", SaveMode.Append) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), - (6 to 24).map(i => Row(i, s"str$i"))) + createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + (6 to 24).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(30, 39).saveAsTable("insertParquet") - } + intercept[AnalysisException] { + createDF(30, 39).write.saveAsTable("insertParquet") + } - createDF(30, 39).saveAsTable("insertParquet", SaveMode.Append) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), - (6 to 34).map(i => Row(i, s"str$i"))) + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + (6 to 34).map(i => Row(i, s"str$i"))) - createDF(40, 49).insertInto("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), - (6 to 44).map(i => Row(i, s"str$i"))) + createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + (6 to 44).map(i => Row(i, s"str$i"))) - createDF(50, 59).saveAsTable("insertParquet", SaveMode.Overwrite) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), - (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).saveAsTable("insertParquet", SaveMode.Ignore) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (50 to 59).map(i => Row(i, s"str$i"))) + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + (52 to 54).map(i => Row(i, s"str$i"))) + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (50 to 59).map(i => Row(i, s"str$i"))) + + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (70 to 79).map(i => Row(i, s"str$i"))) + } + } + + test("SPARK-8156:create table to specific database by 'use dbname' ") { + + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + sqlContext.sql("""create database if not exists testdb8156""") + sqlContext.sql("""use testdb8156""") + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("ttt3") - createDF(70, 79).insertInto("insertParquet", overwrite = true) checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (70 to 79).map(i => Row(i, s"str$i"))) + sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + Row("ttt3", false)) + sqlContext.sql("""use default""") + sqlContext.sql("""drop database if exists testdb8156 CASCADE""") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 4990092df6a9..017bc2adc103 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -20,16 +20,17 @@ package org.apache.spark.sql.hive import com.google.common.io.Files import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest { - import org.apache.spark.sql.hive.test.TestHive.implicits._ + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ + import ctx.sql test("SPARK-5068: query data when path doesn't exist"){ - val testData = TestHive.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") @@ -48,8 +49,8 @@ class QueryPartitionSuite extends QueryTest { // test for the exist path checkAnswer(sql("select key,value from table_with_partition"), - testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect - ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect) + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) // delete the path of one partition tmpDir.listFiles @@ -58,8 +59,7 @@ class QueryPartitionSuite extends QueryTest { // test for after delete the path checkAnswer(sql("select key,value from table_with_partition"), - testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect - ++ testData.toSchemaRDD.collect) + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) sql("DROP TABLE table_with_partition") sql("DROP TABLE createAndInsertTest") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 8afe5459d4f1..93dcb10f7a29 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.hive -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.hive.test.TestHive -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5840] HiveContext should be serializable") { - val hiveContext = TestHive + val hiveContext = org.apache.spark.sql.hive.test.TestHive hiveContext.hiveconf val serializer = new JavaSerializer(new SparkConf()).newInstance() val bytes = serializer.serialize(hiveContext) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 00a69de9e426..f067ea0d4fc7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -23,13 +23,18 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - TestHive.reset() - TestHive.cacheTables = false + + private lazy val ctx: HiveContext = { + val ctx = org.apache.spark.sql.hive.test.TestHive + ctx.reset() + ctx.cacheTables = false + ctx + } + + import ctx.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -72,17 +77,13 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() sql("INSERT INTO TABLE analyzeTable SELECT * FROM src").collect() - // TODO: How does it works? needs to add it back for other hive version. - if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes) - } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) @@ -110,7 +111,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -121,9 +122,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { - analyze("tempTable") + ctx.analyze("tempTable") } - catalog.unregisterTable(Seq("tempTable")) + ctx.catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { @@ -151,8 +152,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold - && sizes(1) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -163,10 +164,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedAnswer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") @@ -175,7 +176,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } after() @@ -207,8 +208,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= conf.autoBroadcastJoinThreshold - && sizes(0) <= conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold + && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -221,10 +222,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, answer) // check correctness of output - TestHive.conf.settings.synchronized { - val tmp = conf.autoBroadcastJoinThreshold + ctx.conf.settings.synchronized { + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j @@ -237,7 +238,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 85b6bc93d712..4056dee77757 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.hive -/* Implicits */ - import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + test("UDF case insensitive") { - udf.register("random0", () => { Math.random()}) - udf.register("RANDOM1", () => { Math.random()}) - udf.register("strlenScala", (_: String).length + (_:Int)) - assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + ctx.udf.register("random0", () => { Math.random() }) + ctx.udf.register("RANDOM1", () => { Math.random() }) + ctx.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 321dc8d7322b..d52e162acbd0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,18 +17,25 @@ package org.apache.spark.sql.hive.client -import org.apache.spark.Logging +import java.io.File + +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils -import org.scalatest.FunSuite /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version - * of hive from maven central. These tests are simple in that they are mostly just testing to make - * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity + * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * of hive from maven central. These tests are simple in that they are mostly just testing to make + * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ -class VersionsSuite extends FunSuite with Logging { +class VersionsSuite extends SparkFunSuite with Logging { + + // Do not use a temp path here to speed up subsequent executions of the unit test during + // development. + private val ivyPath = Some( + new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() @@ -39,7 +46,7 @@ class VersionsSuite extends FunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client + val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -68,19 +75,21 @@ class VersionsSuite extends FunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client } + val badClient = quietly { + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("12", "13") + private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") private var client: ClientInterface = null versions.foreach { version => test(s"$version: create client") { client = null - client = IsolatedClientLoader.forVersion(version, buildConf()).client + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } test(s"$version: createDatabase") { @@ -171,5 +180,12 @@ class VersionsSuite extends FunSuite with Logging { false, false) } + + test(s"$version: create index and reset") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + client.reset() + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index 23ece7e7cf6e..b0d3dd44daed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { +class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 9c056e493bfd..c9dd4c0935a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} +import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -40,7 +40,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { /** * When set, any cache files that result in test failures will be deleted. Used when the test @@ -273,7 +273,7 @@ abstract class HiveComparisonTest } val hiveCacheFiles = queryList.zipWithIndex.map { - case (queryString, i) => + case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" new File(answerCache, cachedAnswerName) } @@ -304,7 +304,7 @@ abstract class HiveComparisonTest // other DDL has not been executed yet. hiveQueries.foreach(_.logical) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { - case ((queryString, i), hiveQuery, cachedAnswerFile)=> + case ((queryString, i), hiveQuery, cachedAnswerFile) => try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 089a57e25c08..991da2f829ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,17 +20,15 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, StructObjectInspector, ObjectInspector} -import org.scalatest.BeforeAndAfter - import scala.util.Try +import org.scalatest.BeforeAndAfter + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive @@ -59,7 +57,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF sql( """ - |CREATE TEMPORARY FUNCTION udtf_count2 + |CREATE TEMPORARY FUNCTION udtf_count2 |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' """.stripMargin) } @@ -111,13 +109,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { | SELECT key FROM gen_tmp ORDER BY key ASC; """.stripMargin) - test("multiple generator in projection") { + test("multiple generators in projection") { intercept[AnalysisException] { - sql("SELECT explode(map(key, value)), key FROM src").collect() + sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() } intercept[AnalysisException] { - sql("SELECT explode(map(key, value)) as k1, k2, key FROM src").collect() + sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() } } @@ -134,7 +132,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { lower("AA"), "10", repeat(lower("AA"), 3), "11", lower(repeat("AA", 3)), "12", - printf("Bb%d", 12), "13", + printf("bb%d", 12), "13", repeat(printf("s%d", 14), 2), "14") FROM src LIMIT 1""") createQueryTest("NaN to Decimal", @@ -326,20 +324,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { | FROM src LIMIT 1 """.stripMargin) - createQueryTest("Date comparison test 2", - "SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1") - - createQueryTest("Date cast", - """ - | SELECT - | CAST(CAST(0 AS timestamp) AS date), - | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), - | CAST(0 AS timestamp), - | CAST(CAST(0 AS timestamp) AS string), - | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) - | FROM src LIMIT 1 - """.stripMargin) - createQueryTest("Simple Average", "SELECT AVG(key) FROM src") @@ -418,6 +402,25 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |SELECT * FROM createdtable; """.stripMargin) + test("SPARK-7270: consider dynamic partition when comparing table output") { + sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") + sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") + + val analyzedPlan = sql( + """ + |INSERT OVERWRITE table test_partition PARTITION (b=1, c) + |SELECT 'a', 'c' from ptest + """.stripMargin).queryExecution.analyzed + + assertResult(false, "Incorrect cast detected\n" + analyzedPlan) { + var hasCast = false + analyzedPlan.collect { + case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c } + } + hasCast + } + } + createQueryTest("transform", "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src") @@ -857,15 +860,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |WITH serdeproperties('s1'='9') """.stripMargin) } - // Now only verify 0.12.0, and ignore other versions due to binary compatibility - // current TestSerDe.jar is from 0.12.0 - if (HiveShim.version == "0.12.0") { - sql(s"ADD JAR $testJar") - sql( - """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' - |WITH serdeproperties('s1'='9') - """.stripMargin) - } sql("DROP TABLE alter1") } @@ -873,15 +867,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // this is a test case from mapjoin_addjar.q val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath - if (HiveShim.version == "0.13.1") { - sql(s"ADD JAR $testJar") - sql( - """CREATE TABLE t1(a string, b string) - |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) - sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") - sql("select * from src join t1 on src.key = t1.a") - sql("DROP TABLE t1") - } + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") } test("ADD FILE command") { @@ -1078,14 +1070,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - val KV = "([^=]+)=([^=]*)".r - def collectResults(df: DataFrame): Set[(String, String)] = + def collectResults(df: DataFrame): Set[Any] = df.collect().map { case Row(key: String, value: String) => key -> value - case Row(KV(key, value)) => key -> value + case Row(key: String, defaultValue: String, doc: String) => (key, defaultValue, doc) }.toSet conf.clear() + val expectedConfs = conf.getAllDefinedConfs.toSet + assertResult(expectedConfs)(collectResults(sql("SET -v"))) + // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... assert(sql("SET").collect().size == 0) @@ -1096,16 +1090,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET -v")) - } // "SET key" assertResult(Set(testKey -> testVal)) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 8ad362750422..b08db6de2d2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) @@ -31,14 +31,14 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) class HiveResolutionSuite extends HiveComparisonTest { test("SPARK-3698: case insensitive test for nested data") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") // there are 2 filed matching field name "b", we should report Ambiguous reference error @@ -77,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") @@ -88,14 +88,14 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index ab53c6309e08..2209fc2f30a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -61,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() TestHive.sql("drop table tb") } - + test("Spark-4077: timestamp query for null value") { TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") TestHive.sql( @@ -71,12 +71,12 @@ class HiveTableScanSuite extends HiveComparisonTest { FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' """.stripMargin) - val location = + val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() - === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index f0f04f8c73fb..197e9bfb02c4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } - - test("COALESCE with different types") { - intercept[RuntimeException] { - TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala similarity index 83% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 7f49eac49057..192e76ed8c39 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -22,15 +22,14 @@ import java.util import java.util.Properties import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHive - +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.util.Utils import scala.collection.JavaConversions._ @@ -46,10 +45,10 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUdfSuite extends QueryTest { +class HiveUDFSuite extends QueryTest { - import TestHive.{udf, sql} - import TestHive.implicits._ + import org.apache.spark.sql.hive.test.TestHive.implicits._ + import org.apache.spark.sql.hive.test.TestHive.{sql, udf} test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -73,7 +72,7 @@ class HiveUdfSuite extends QueryTest { test("hive struct udf") { sql( """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( + |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) @@ -82,15 +81,15 @@ class HiveUdfSuite extends QueryTest { """. stripMargin.format(classOf[PairSerDe].getName)) - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") } test("SPARK-6409 UDAFAverage test") { @@ -101,7 +100,7 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") TestHive.reset() } - + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) @@ -133,6 +132,32 @@ class HiveUdfSuite extends QueryTest { TestHive.reset() } + test("UDFToListString") { + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") + intercept[AnalysisException] { + sql("SELECT testUDFToListString(s) FROM inputTable") + } + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") + TestHive.reset() + } + + test("UDFToListInt") { + val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.registerTempTable("inputTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + intercept[AnalysisException] { + sql("SELECT testUDFToListInt(s) FROM inputTable") + } + + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") + TestHive.reset() + } + test("UDFListListInt") { val testData = TestHive.sparkContext.parallelize( ListListIntCaseClass(Nil) :: @@ -169,11 +194,11 @@ class HiveUdfSuite extends QueryTest { StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") TestHive.reset() } @@ -244,7 +269,7 @@ class PairSerDe extends AbstractSerDe { } } -class PairUdf extends GenericUDF { +class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index dfe73c62c42b..6d645393a6da 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.hive.execution +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -191,9 +193,9 @@ class SQLQuerySuite extends QueryTest { } } - val originalConf = getConf("spark.sql.hive.convertCTAS", "false") + val originalConf = convertCTAS - setConf("spark.sql.hive.convertCTAS", "true") + setConf(HiveContext.CONVERT_CTAS, true) sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") @@ -235,7 +237,7 @@ class SQLQuerySuite extends QueryTest { checkRelation("ctas1", false) sql("DROP TABLE ctas1") - setConf("spark.sql.hive.convertCTAS", originalConf) + setConf(HiveContext.CONVERT_CTAS, originalConf) } test("SQL Dialect Switching") { @@ -327,41 +329,57 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.ql.io.RCFileInputFormat", "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22","MANAGED_TABLE" + "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - if (HiveShim.version =="0.13.1") { - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() - - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") - checkAnswer( - sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) - } + val origUseParquetDataSource = conf.parquetUseDataSourceApi + try { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + val default = convertMetastoreParquet + // use the Hive SerDe for parquet tables + sql("set spark.sql.hive.convertMetastoreParquet = false") + checkAnswer( + sql("SELECT key, value FROM ctas5 ORDER BY key, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql(s"set spark.sql.hive.convertMetastoreParquet = $default") + } finally { + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource) } } + test("specifying the column list for CTAS") { + Seq((1, "111111"), (2, "222222")).toDF("key", "value").registerTempTable("mytable1") + + sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select key, value from mytable1").collect()) + sql("DROP TABLE gen__tmp") + + sql("create table gen__tmp(a double, b double) as select key, value from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select cast(key as double), cast(value as double) from mytable1").collect()) + sql("DROP TABLE gen__tmp") + + sql("drop table mytable1") + } + test("command substitution") { sql("set tbl=src") checkAnswer( @@ -425,10 +443,10 @@ class SQLQuerySuite extends QueryTest { test("SPARK-4825 save join to table") { val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") - testData.insertInto("test1") + testData.write.mode(SaveMode.Append).insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") - testData.insertInto("test2") - testData.insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( table("test"), @@ -535,26 +553,49 @@ class SQLQuerySuite extends QueryTest { test("SPARK-4296 Grouping field with Hive UDF as sub expression") { val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer( sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) dropTempTable("data") - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) dropTempTable("data") } - test("resolve udtf with single alias") { + test("resolve udtf in projection #1") { val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") } + test("resolve udtf in projection #2") { + val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) + jsonRDD(rdd).registerTempTable("data") + checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) + checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) + intercept[AnalysisException] { + sql("SELECT explode(map(1, 1)) as k1 FROM data LIMIT 1") + } + + intercept[AnalysisException] { + sql("SELECT explode(map(1, 1)) as (k1, k2, k3) FROM data LIMIT 1") + } + } + + // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive + test("TGF with non-TGF in projection") { + val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) + jsonRDD(rdd).registerTempTable("data") + checkAnswer( + sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), + Row("1", "1", "1", "1") :: Nil) + } + test("logical.Project should not be resolved if it contains aggregates or generators") { // This test is used to test the fix of SPARK-5875. // The original issue was that Project's resolved will be true when it contains @@ -563,9 +604,9 @@ class SQLQuerySuite extends QueryTest { // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") - val originalConf = getConf("spark.sql.hive.convertCTAS", "false") - setConf("spark.sql.hive.convertCTAS", "false") + read.json(rdd).registerTempTable("data") + val originalConf = convertCTAS + setConf(HiveContext.CONVERT_CTAS, false) sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { @@ -582,7 +623,7 @@ class SQLQuerySuite extends QueryTest { sql("DROP TABLE explodeTest") dropTempTable("data") - setConf("spark.sql.hive.convertCTAS", originalConf) + setConf(HiveContext.CONVERT_CTAS, originalConf) } test("sanity test for SPARK-6618") { @@ -599,19 +640,27 @@ class SQLQuerySuite extends QueryTest { test("SPARK-5203 union with different decimal precision") { Seq.empty[(Decimal, Decimal)] .toDF("d1", "d2") - .select($"d1".cast(DecimalType(10, 15)).as("d")) + .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed } - test("test script transform") { + test("test script transform for stdout") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(100000 === sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + .queryExecution.toRdd.count()) + } + + test("test script transform for stderr") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(0 === + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") + .queryExecution.toRdd.count()) } test("window function: udaf with aggregate expressin") { @@ -757,6 +806,42 @@ class SQLQuerySuite extends QueryTest { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: multiple window expressions in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.registerTempTable("nums") + + val expected = + Row(1, 1, 1, 55, 1, 57) :: + Row(0, 2, 3, 55, 2, 60) :: + Row(1, 3, 6, 55, 4, 65) :: + Row(0, 4, 10, 55, 6, 71) :: + Row(1, 5, 15, 55, 9, 79) :: + Row(0, 6, 21, 55, 12, 88) :: + Row(1, 7, 28, 55, 16, 99) :: + Row(0, 8, 36, 55, 20, 111) :: + Row(1, 9, 45, 55, 25, 125) :: + Row(0, 10, 55, 55, 30, 140) :: Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) OVER w1 AS running_sum, + | sum(x) OVER w2 AS total_sum, + | sum(x) OVER w3 AS running_sum_per_y, + | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), + | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), + | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """.stripMargin) + + checkAnswer(actual, expected) + + dropTempTable("nums") + } + test("test case key when") { (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") checkAnswer( @@ -773,4 +858,137 @@ class SQLQuerySuite extends QueryTest { | select * from v2 order by key limit 1 """.stripMargin), Row(0, 3)) } + + test("SPARK-7269 Check analysis failed in case in-sensitive") { + Seq(1, 2, 3).map { i => + (i.toString, i.toString) + }.toDF("key", "value").registerTempTable("df_analysis") + sql("SELECT kEy from df_analysis group by key").collect() + sql("SELECT kEy+3 from df_analysis group by key+3").collect() + sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect() + sql("SELECT 2 from df_analysis A group by key+1").collect() + intercept[AnalysisException] { + sql("SELECT kEy+1 from df_analysis group by key+3") + } + intercept[AnalysisException] { + sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)") + } + } + + test("Cast STRING to BIGINT") { + checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) + } + + // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest + test("udf_java_method") { + checkAnswer(sql( + """ + |SELECT java_method("java.lang.String", "valueOf", 1), + | java_method("java.lang.String", "isEmpty"), + | java_method("java.lang.Math", "max", 2, 3), + | java_method("java.lang.Math", "min", 2, 3), + | java_method("java.lang.Math", "round", 2.5), + | java_method("java.lang.Math", "exp", 1.0), + | java_method("java.lang.Math", "floor", 1.9) + |FROM src tablesample (1 rows) + """.stripMargin), + Row( + "1", + "true", + java.lang.Math.max(2, 3).toString, + java.lang.Math.min(2, 3).toString, + java.lang.Math.round(2.5).toString, + java.lang.Math.exp(1.0).toString, + java.lang.Math.floor(1.9).toString)) + } + + test("dynamic partition value test") { + try { + sql("set hive.exec.dynamic.partition.mode=nonstrict") + // date + sql("drop table if exists dynparttest1") + sql("create table dynparttest1 (value int) partitioned by (pdate date)") + sql( + """ + |insert into table dynparttest1 partition(pdate) + | select count(*), cast('2015-05-21' as date) as pdate from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest1"), + Seq(Row(500, java.sql.Date.valueOf("2015-05-21")))) + + // decimal + sql("drop table if exists dynparttest2") + sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") + sql( + """ + |insert into table dynparttest2 partition(pdec) + | select count(*), cast('100.12' as decimal(5, 1)) as pdec from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest2"), + Seq(Row(500, new java.math.BigDecimal("100.1")))) + } finally { + sql("drop table if exists dynparttest1") + sql("drop table if exists dynparttest2") + sql("set hive.exec.dynamic.partition.mode=strict") + } + } + + test("Call add jar in a different thread (SPARK-8306)") { + @volatile var error: Option[Throwable] = None + val thread = new Thread { + override def run() { + // To make sure this test works, this jar should not be loaded in another place. + TestHive.sql( + s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + try { + TestHive.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + } catch { + case throwable: Throwable => + error = Some(throwable) + } + } + } + thread.start() + thread.join() + error match { + case Some(throwable) => + fail("CREATE TEMPORARY FUNCTION should not fail.", throwable) + case None => // OK + } + } + + test("SPARK-6785: HiveQuerySuite - Date comparison test 2") { + checkAnswer( + sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), + Row(false)) + } + + test("SPARK-6785: HiveQuerySuite - Date cast") { + // new Date(0) == 1970-01-01 00:00:00.0 GMT == 1969-12-31 16:00:00.0 PST + checkAnswer( + sql( + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 + """.stripMargin), + Row( + Date.valueOf("1969-12-31"), + String.valueOf("1969-12-31"), + Timestamp.valueOf("1969-12-31 16:00:00"), + String.valueOf("1969-12-31 16:00:00"), + Timestamp.valueOf("1970-01-01 00:00:00"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..080af5bb23c1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.sources.HadoopFsRelationTest +import org.apache.spark.sql.types._ + +class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write + .format("orc") + .save(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala new file mode 100644 index 000000000000..8707f9f936be --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.InternalRow +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.BeforeAndAfterAll + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + + +// The data where the partitioning key exists only in the directory structure. +case class OrcParData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot +class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal + + def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath) + } + + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.format("orc").mode("overwrite").save(path.getCanonicalPath) + } + + protected def withTempTable(tableName: String)(f: => Unit): Unit = { + try f finally TestHive.dropTempTable(tableName) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } +} + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala new file mode 100644 index 000000000000..ca131faaeef0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +case class AllDataTypesWithNonPrimitiveType( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean, + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapValueContainsNull: Map[Int, Option[Long]], + data: (Seq[Int], (Int, String))) + +case class BinaryData(binaryData: Array[Byte]) + +case class Contact(name: String, phone: String) + +case class Person(name: String, age: Int, contacts: Seq[Contact]) + +class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { + + def getTempFilePath(prefix: String, suffix: String = ""): File = { + val tempFile = File.createTempFile(prefix, suffix) + tempFile.delete() + tempFile + } + + test("Read/write All Types") { + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) + } + + withOrcFile(data) { file => + checkAnswer( + sqlContext.read.format("orc").load(file), + data.toDF().collect()) + } + } + + test("Read/write binary data") { + withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => + val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0) + assert(new String(bytes, "utf8") === "test") + } + } + + test("Read/write all types with non-primitive type") { + val data = (0 to 255).map { i => + AllDataTypesWithNonPrimitiveType( + s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + 0 until i, + (0 until i).map(Option(_).filter(_ % 3 == 0)), + (0 until i).map(i => i -> i.toLong).toMap, + (0 until i).map(i => i -> Option(i.toLong)).toMap + (i -> None), + (0 until i, (i, s"$i"))) + } + + withOrcFile(data) { file => + checkAnswer( + read.format("orc").load(file), + data.toDF().collect()) + } + } + + test("Creating case class RDD table") { + val data = (1 to 100).map(i => (i, s"val_$i")) + sparkContext.parallelize(data).toDF().registerTempTable("t") + withTempTable("t") { + checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) + } + } + + test("Simple selection form ORC table") { + val data = (1 to 10).map { i => + Person(s"name_$i", i, (0 to 1).map { m => Contact(s"contact_$m", s"phone_$m") }) + } + + withOrcTable(data, "t") { + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = leaf-0 + assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5) + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = (not leaf-0) + assertResult(10) { + sql("SELECT name, contacts FROM t where age > 5") + .flatMap(_.getAs[Seq[_]]("contacts")) + .count() + } + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // leaf-1 = (LESS_THAN age 8) + // expr = (and (not leaf-0) leaf-1) + { + val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") + assert(df.count() === 2) + assertResult(4) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + + // ppd: + // leaf-0 = (LESS_THAN age 2) + // leaf-1 = (LESS_THAN_EQUALS age 8) + // expr = (or leaf-0 (not leaf-1)) + { + val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") + assert(df.count() === 3) + assertResult(6) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + } + } + + test("save and load case class RDD with `None`s as orc") { + val data = ( + None: Option[Int], + None: Option[Long], + None: Option[Float], + None: Option[Double], + None: Option[Boolean] + ) :: Nil + + withOrcFile(data) { file => + checkAnswer( + read.format("orc").load(file), + Row(Seq.fill(5)(null): _*)) + } + } + + // We only support zlib in Hive 0.12.0 now + test("Default compression options for writing to an ORC file") { + withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file => + assertResult(CompressionKind.ZLIB) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + } + + // Following codec is supported in hive-0.13.1, ignore it now + ignore("Other compression options for writing to an ORC file - 0.13.1 and above") { + val data = (1 to 100).map(i => (i, s"val_$i")) + val conf = sparkContext.hadoopConfiguration + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") + withOrcFile(data) { file => + assertResult(CompressionKind.SNAPPY) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE") + withOrcFile(data) { file => + assertResult(CompressionKind.NONE) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO") + withOrcFile(data) { file => + assertResult(CompressionKind.LZO) { + OrcFileOperator.getFileReader(file).get.getCompression + } + } + } + + test("simple select queries") { + withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer( + sql("SELECT `_1` FROM t where t.`_1` > 5"), + (6 until 10).map(Row.apply(_))) + + checkAnswer( + sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withOrcTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val queryOutput = selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("columns only referenced by pushed down filters should remain") { + withOrcTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in orc") { + withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + List(Row("same", "run_5", 100))) + } + } + + test("SPARK-8501: Avoids discovery schema from empty ORC files") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withTable("empty_orc") { + withTempTable("empty", "single") { + sqlContext.sql( + s"""CREATE TABLE empty_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '$path' + """.stripMargin) + + val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) + emptyDF.registerTempTable("empty") + + // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because + // Spark SQL ORC data source always avoids write empty ORC files. + sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM empty + """.stripMargin) + + val errorMessage = intercept[AnalysisException] { + sqlContext.read.format("orc").load(path) + }.getMessage + + assert(errorMessage.contains("Failed to discover schema from ORC files")) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.registerTempTable("single") + + sqlContext.sql( + s"""INSERT INTO TABLE empty_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = sqlContext.read.format("orc").load(path) + assert(df.schema === singleRowDF.schema.asNullable) + checkAnswer(df, singleRowDF) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala new file mode 100644 index 000000000000..82e08caf4645 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{QueryTest, Row} + +case class OrcData(intField: Int, stringField: String) + +abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { + var orcTableDir: File = null + var orcTableAsDir: File = null + + override def beforeAll(): Unit = { + super.beforeAll() + + orcTableAsDir = File.createTempFile("orctests", "sparksql") + orcTableAsDir.delete() + orcTableAsDir.mkdir() + + // Hack: to prepare orc data files using hive external tables + orcTableDir = File.createTempFile("orctests", "sparksql") + orcTableDir.delete() + orcTableDir.mkdir() + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + sparkContext + .makeRDD(1 to 10) + .map(i => OrcData(i, s"part-$i")) + .toDF() + .registerTempTable(s"orc_temp_table") + + sql( + s"""CREATE EXTERNAL TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |STORED AS ORC + |LOCATION '${orcTableAsDir.getCanonicalPath}' + """.stripMargin) + + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) + } + + override def afterAll(): Unit = { + orcTableDir.delete() + orcTableAsDir.delete() + } + + test("create temporary orc table") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT * FROM normal_orc_source where intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) + } + + test("create temporary orc table as") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT * FROM normal_orc_source WHERE intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) + } + + test("appending insert") { + sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => + Seq.fill(2)(Row(i, s"part-$i")) + }) + } + + test("overwrite insert") { + sql( + """INSERT OVERWRITE TABLE normal_orc_as_source + |SELECT * FROM orc_temp_table WHERE intField > 5 + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM normal_orc_as_source"), + (6 to 10).map(i => Row(i, s"part-$i"))) + } +} + +class OrcSourceSuite extends OrcSuite { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_as_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala new file mode 100644 index 000000000000..5daf691aa8c5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql._ + +private[sql] trait OrcTest extends SQLTestUtils { + lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive + + import sqlContext.sparkContext + import sqlContext.implicits._ + + /** + * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withOrcFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Orc file will be deleted after `f` returns. + */ + protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path))) + } + + /** + * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Orc file will be dropped/deleted after `f` returns. + */ + protected def withOrcTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withOrcDataFrame(data) { df => + sqlContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index b6be09e2f883..c2e09800933b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,7 +21,6 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ @@ -29,7 +28,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, SQLConf, SaveMode} +import org.apache.spark.sql.{DataFrame, QueryTest, Row, SQLConf, SaveMode} import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. @@ -37,7 +36,7 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) -case class StructContainer(intStructField :Int, stringStructField: String) +case class StructContainer(intStructField: Int, stringStructField: String) case class ParquetDataWithComplexTypes( intField: Int, @@ -150,11 +149,11 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { } val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - jsonRDD(rdd1).registerTempTable("jt") + read.json(rdd1).registerTempTable("jt") val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) - jsonRDD(rdd2).registerTempTable("jt_array") + read.json(rdd2).registerTempTable("jt_array") - setConf("spark.sql.hive.convertMetastoreParquet", "true") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) } override def afterAll(): Unit = { @@ -165,7 +164,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { sql("DROP TABLE normal_parquet") sql("DROP TABLE IF EXISTS jt") sql("DROP TABLE IF EXISTS jt_array") - setConf("spark.sql.hive.convertMetastoreParquet", "false") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } test(s"conversion is working") { @@ -200,14 +199,14 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override def afterAll(): Unit = { super.afterAll() sql("DROP TABLE IF EXISTS test_parquet") - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("scan an empty parquet table") { @@ -315,7 +314,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation2].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + @@ -345,7 +344,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation2].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + @@ -385,10 +384,58 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE ms_convert") } + def collectParquetRelation(df: DataFrame): ParquetRelation2 = { + val plan = df.queryExecution.analyzed + plan.collectFirst { + case LogicalRelation(r: ParquetRelation2) => r + }.getOrElse { + fail(s"Expecting a ParquetRelation2, but got:\n$plan") + } + } + + test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { + sql( + s"""CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("nonPartitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("nonPartitioned")) + // They should be the same instance + assert(r1 eq r2) + + sql("DROP TABLE nonPartitioned") + } + + test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { + sql( + s"""CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + + sql("DROP TABLE partitioned") + } + test("Caching converted data source Parquet Relations") { - def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + def checkCached(tableIdentifier: catalog.QualifiedTableName): Unit = { // Converted test_parquet should be cached. - catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK case other => @@ -414,30 +461,30 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + var tableIdentifier = catalog.QualifiedTableName("default", "test_insert_parquet") // First, make sure the converted test_parquet is not cached. - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Table lookup will make the table cached. table("test_insert_parquet") - checkCached(tableIdentifer) + checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_insert_parquet |select a, b from jt """.stripMargin) - checkCached(tableIdentifer) + checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( sql("select * from test_insert_parquet"), sql("select a, b from jt").collect()) // Invalidate the cache. invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Create a partitioned table. sql( @@ -454,8 +501,8 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + tableIdentifier = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -464,18 +511,18 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (date='2015-04-02') |select a, b from jt """.stripMargin) - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") - checkCached(tableIdentifer) + checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), @@ -487,7 +534,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { """.stripMargin).collect()) invalidateTable("test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql("DROP TABLE test_insert_parquet") sql("DROP TABLE test_parquet_partitioned_cache_test") @@ -499,12 +546,12 @@ class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("MetastoreRelation in InsertIntoTable will not be converted") { @@ -617,16 +664,16 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { sql("drop table if exists spark_6016_fix") // Create a DataFrame with two partitions. So, the created table will have two parquet files. - val df1 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) - df1.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + val df1 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") checkAnswer( sql("select * from spark_6016_fix"), (1 to 10).map(i => Row(i)) ) // Create a DataFrame with four partitions. So, the created table will have four parquet files. - val df2 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) - df2.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + val df2 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, // since the new table has four parquet files, we are trying to read new footers from two files // and then merge metadata in footers of these four (two outdated ones and two latest one), @@ -645,12 +692,12 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("values in arrays and maps stored in parquet are always nullable") { @@ -663,7 +710,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) - df.saveAsTable("alwaysNullable", "parquet") + df.write.format("parquet").saveAsTable("alwaysNullable") val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) val arrayType2 = ArrayType(IntegerType, containsNull = true) @@ -686,13 +733,13 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val filePath = new File(tempDir, "testParquet").getCanonicalPath val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath - val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") - intercept[Throwable](df2.saveAsParquetFile(filePath)) + intercept[Throwable](df2.write.parquet(filePath)) val df3 = df2.toDF("str", "max_int") - df3.saveAsParquetFile(filePath2) - val df4 = parquetFile(filePath2) + df3.write.parquet(filePath2) + val df4 = read.parquet(filePath2) checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) assert(df4.columns === Array("str", "max_int")) } @@ -703,12 +750,12 @@ class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } @@ -731,14 +778,14 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10) .map(i => ParquetData(i, s"part-$p")) .toDF() - .saveAsParquetFile(partDir.getCanonicalPath) + .write.parquet(partDir.getCanonicalPath) } sparkContext .makeRDD(1 to 10) .map(i => ParquetData(i, s"part-1")) .toDF() - .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) + .write.parquet(new File(normalTableDir, "normal").getCanonicalPath) partitionedTableDirWithKey = Utils.createTempDir() @@ -747,7 +794,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10) .map(i => ParquetDataWithKey(p, i, s"part-$p")) .toDF() - .saveAsParquetFile(partDir.getCanonicalPath) + .write.parquet(partDir.getCanonicalPath) } partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() @@ -757,7 +804,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithKeyAndComplexTypes( p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + }.toDF().write.parquet(partDir.getCanonicalPath) } partitionedTableDirWithComplexTypes = Utils.createTempDir() @@ -766,7 +813,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + }.toDF().write.parquet(partDir.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 29b21586f9c2..e8141923a9b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -21,14 +21,15 @@ import java.text.NumberFormat import java.util.UUID import com.google.common.base.Objects -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext} /** @@ -52,9 +53,10 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") val split = context.getTaskAttemptID.getTaskID.getId val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-${UUID.randomUUID()}") + new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } } @@ -67,7 +69,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends recordWriter.write(null, new Text(serialized)) } - override def close(): Unit = recordWriter.close(context) + override def close(): Unit = { + recordWriter.close(context) + } } /** @@ -99,19 +103,24 @@ class SimpleTextRelation( } override def hashCode(): Int = - Objects.hashCode(paths, maybeDataSchema, dataSchema) + Objects.hashCode(paths, maybeDataSchema, dataSchema, partitionColumns) - override def buildScan(inputPaths: Array[String]): RDD[Row] = { + override def buildScan(inputStatuses: Array[FileStatus]): RDD[Row] = { val fields = dataSchema.map(_.dataType) - sparkContext.textFile(inputPaths.mkString(",")).map { record => + sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => Row(record.split(",").zip(fields).map { case (value, dataType) => - Cast(Literal(value), dataType).eval() + // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) + val catalystValue = Cast(Literal(value), dataType).eval() + // Here we're converting Catalyst values to Scala values to test `needsConversion` + CatalystTypeConverters.convertToScala(catalystValue, dataType) }: _*) } } override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) + override def newInstance( path: String, dataSchema: StructType, @@ -120,3 +129,40 @@ class SimpleTextRelation( } } } + +/** + * A simple example [[HadoopFsRelationProvider]]. + */ +class CommitFailureTestSource extends HadoopFsRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext) + } +} + +class CommitFailureTestRelation( + override val paths: Array[String], + maybeDataSchema: Option[StructType], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient sqlContext: SQLContext) + extends SimpleTextRelation( + paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) { + override def close(): Unit = { + super.close() + sys.error("Intentional task commitment failure for testing purpose.") + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index cf6afd25ae5a..afecf9675e11 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,24 +17,31 @@ package org.apache.spark.sql.sources +import scala.collection.JavaConversions._ + +import java.io.File + +import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.ParquetOutputCommitter +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -// TODO Don't extend ParquetTest -// This test suite extends ParquetTest for some convenient utility methods. These methods should be -// moved to some more general places, maybe QueryTest. -class HadoopFsRelationTest extends QueryTest with ParquetTest { - override val sqlContext: SQLContext = TestHive +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { + override lazy val sqlContext: SQLContext = TestHive - import sqlContext._ + import sqlContext.sql import sqlContext.implicits._ - val dataSourceName = classOf[SimpleTextSource].getCanonicalName + val dataSourceName: String val dataSchema = StructType( @@ -42,19 +49,19 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { StructField("a", IntegerType, nullable = false), StructField("b", StringType, nullable = false))) - val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") + lazy val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") - val partitionedTestDF1 = (for { + lazy val partitionedTestDF1 = (for { i <- 1 to 3 p2 <- Seq("foo", "bar") } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") - val partitionedTestDF2 = (for { + lazy val partitionedTestDF2 = (for { i <- 1 to 3 p2 <- Seq("foo", "bar") } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") - val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) + lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) def checkQueries(df: DataFrame): Unit = { // Selects everything @@ -77,6 +84,12 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { df.filter('a > 1 && 'p1 < 2).select('b, 'p1), for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) + // Project many copies of columns with different types (reproduction for SPARK-7858) + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) + yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1)) + // Self-join df.registerTempTable("t") withTempTable("t") { @@ -92,44 +105,27 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - non-partitioned table - Overwrite") { withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + sqlContext.read.format(dataSourceName) + .option("path", file.getCanonicalPath) + .option("dataSchema", dataSchema.json) + .load(), testDF.collect()) } } test("save()/load() - non-partitioned table - Append") { withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Append) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)).orderBy("a"), + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath).orderBy("a"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -137,20 +133,14 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - non-partitioned table - ErrorIfExists") { withTempDir { file => intercept[RuntimeException] { - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.ErrorIfExists) + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) } } } test("save()/load() - non-partitioned table - Ignore") { withTempDir { file => - testDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Ignore) + testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) @@ -160,89 +150,81 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - partitioned table - simple queries") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json))) + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath)) } } test("save()/load() - partitioned table - Overwrite") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), partitionedTestDF.collect()) } } test("save()/load() - partitioned table - Append") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } test("save()/load() - partitioned table - Append - new partition values") { withTempPath { file => - partitionedTestDF1.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.save( - source = dataSourceName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) checkAnswer( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), + sqlContext.read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), partitionedTestDF.collect()) } } @@ -250,21 +232,19 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { test("save()/load() - partitioned table - ErrorIfExists") { withTempDir { file => intercept[RuntimeException] { - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) } } } test("save()/load() - partitioned table - Ignore") { withTempDir { file => - partitionedTestDF.save( - path = file.getCanonicalPath, - source = dataSourceName, - mode = SaveMode.Ignore) + partitionedTestDF.write + .format(dataSourceName).mode(SaveMode.Ignore).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) val fs = path.getFileSystem(SparkHadoopUtil.get.conf) @@ -272,35 +252,22 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { } } - def withTable(tableName: String)(f: => Unit): Unit = { - try f finally sql(s"DROP TABLE $tableName") - } - test("saveAsTable()/load() - non-partitioned table - Overwrite") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.collect()) + checkAnswer(sqlContext.table("t"), testDF.collect()) } } test("saveAsTable()/load() - non-partitioned table - Append") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite) - - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append) + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite).saveAsTable("t") + testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -309,10 +276,7 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { withTempTable("t") { intercept[AnalysisException] { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.ErrorIfExists) + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") } } } @@ -321,113 +285,98 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { Seq.empty[(Int, String)].toDF().registerTempTable("t") withTempTable("t") { - testDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Ignore) - - assert(table("t").collect().isEmpty) + testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") + assert(sqlContext.table("t").collect().isEmpty) } } test("saveAsTable()/load() - partitioned table - simple queries") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) + partitionedTestDF.write.format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") withTable("t") { - checkQueries(table("t")) + checkQueries(sqlContext.table("t")) } } test("saveAsTable()/load() - partitioned table - Overwrite") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } test("saveAsTable()/load() - partitioned table - Append") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } test("saveAsTable()/load() - partitioned table - Append - new partition values") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") // Using only a subset of all partition columns intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1")) - } - - // Using different order of partition columns - intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p2", "p1")) + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1") + .saveAsTable("t") } } @@ -436,12 +385,12 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { withTempTable("t") { intercept[AnalysisException] { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.ErrorIfExists, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") } } } @@ -450,30 +399,29 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { Seq.empty[(Int, String)].toDF().registerTempTable("t") withTempTable("t") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = dataSourceName, - mode = SaveMode.Ignore, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - assert(table("t").collect().isEmpty) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Ignore) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + assert(sqlContext.table("t").collect().isEmpty) } } test("Hadoop style globbing") { withTempPath { file => - partitionedTestDF.save( - source = dataSourceName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - val df = load( - source = dataSourceName, - options = Map( - "path" -> s"${file.getCanonicalPath}/p1=*/p2=???", - "dataSchema" -> dataSchema.json)) + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + val df = sqlContext.read + .format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(s"${file.getCanonicalPath}/p1=*/p2=???") val expectedPaths = Set( s"${file.getCanonicalFile}/p1=1/p2=foo", @@ -497,6 +445,139 @@ class HadoopFsRelationTest extends QueryTest with ParquetTest { checkAnswer(df, partitionedTestDF.collect()) } } + + test("Partition column type casting") { + withTempPath { file => + val input = partitionedTestDF.select('a, 'b, 'p1.cast(StringType).as('ps), 'p2) + + input + .write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("ps", "p2") + .saveAsTable("t") + + withTempTable("t") { + checkAnswer(sqlContext.table("t"), input.collect()) + } + } + } + + test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + + df.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("c", "a") + .saveAsTable("t") + + withTable("t") { + checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) + } + } + + // NOTE: This test suite is not super deterministic. On nodes with only relatively few cores + // (4 or even 1), it's hard to reproduce the data loss issue. But on nodes with for example 8 or + // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this + // requirement. We probably want to move this test case to spark-integration-tests or spark-perf + // later. + test("SPARK-8406: Avoids name collision while writing files") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext + .range(10000) + .repartition(250) + .write + .mode(SaveMode.Overwrite) + .format(dataSourceName) + .save(path) + + assertResult(10000) { + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) + .load(path) + .count() + } + } + } + + test("SPARK-8578 specified custom output committer will not be used to append data") { + val clonedConf = new Configuration(configuration) + try { + val df = sqlContext.range(1, 10).toDF("i") + withTempPath { dir => + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there data already exists, + // this append should succeed because we will use the output committer associated + // with file format and AlwaysFailOutputCommitter will not be used. + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext.read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df.unionAll(df)) + + // This will fail because AlwaysFailOutputCommitter is used when we do append. + intercept[Exception] { + df.write.mode("overwrite").format(dataSourceName).save(dir.getCanonicalPath) + } + } + withTempPath { dir => + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there is no existing data, + // this append will fail because AlwaysFailOutputCommitter is used when we do append + // and there is no existing data. + intercept[Exception] { + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + } + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailParquetOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -521,11 +602,31 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } +} + +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { + override val sqlContext = TestHive + + // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + // Here we coalesce partition number to 1 to ensure that only a single task is issued. This + // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details. + val df = sqlContext.range(0, 10).coalesce(1) + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) } } } @@ -547,18 +648,96 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") - .saveAsParquetFile(partitionDir.toString) + .write.parquet(partitionDir.toString) } val dataSchemaWithPartition = StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[RuntimeException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. + range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) } } } diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala deleted file mode 100644 index 33e96eaabfbf..000000000000 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI -import java.util.{ArrayList => JArrayList, Properties} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions - -import org.apache.hadoop.{io => hadoopIo} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.stats.StatsSetupConst -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat - -import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} - -private[hive] case class HiveFunctionWrapper(functionClassName: String) - extends java.io.Serializable { - - // for Serialization - def this() = this(null) - - import org.apache.spark.util.Utils._ - def createFunction[UDFType <: AnyRef](): UDFType = { - getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - } -} - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[hive] object HiveShim { - val version = "0.12.0" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) - } - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, - getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, - getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, - getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, - getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, - getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, - getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, - getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, - getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, - getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, - getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, - getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DECIMAL, - getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.VOID, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[String] - - def processResults(results: JArrayList[String]) = results - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd(0), conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - new HiveDecimal(bd) - } - - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - ColumnProjectionUtils.appendReadColumnIDs(conf, ids) - ColumnProjectionUtils.appendReadColumnNames(conf, names) - } - - def getExternalTmpPath(context: Context, uri: URI) = { - context.getExternalTmpFileURI(uri) - } - - def getDataLocationPath(p: Partition) = p.getPartitionPath - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) - - def compatibilityBlackList = Seq( - "decimal_.*", - "udf7", - "drop_partitions_filter2", - "show_.*", - "serde_regex", - "udf_to_date", - "udaf_collect_set", - "udf_concat" - ) - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) - } - - def decimalMetastoreString(decimalType: DecimalType): String = "decimal" - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = - TypeInfoFactory.decimalTypeInfo - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - DecimalType.Unlimited - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - } - } - - def getConvertedOI( - inputOI: ObjectInspector, - outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) - } - - def prepareWritable(w: Writable): Writable = { - w - } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {} -} - -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends FileSinkDesc(dir, tableInfo, compressed) { -} diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala deleted file mode 100644 index dbc5e029e204..000000000000 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ /dev/null @@ -1,457 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.rmi.server.UID -import java.util.{Properties, ArrayList => JArrayList} -import java.io.{OutputStream, InputStream} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions -import scala.reflect.ClassTag - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.{io => hadoopIo} - -import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} -import org.apache.spark.util.Utils._ - -/** - * This class provides the UDF creation and also the UDF instance serialization and - * de-serialization cross process boundary. - * - * Detail discussion can be found at https://github.com/apache/spark/pull/3640 - * - * @param functionClassName UDF class name - */ -private[hive] case class HiveFunctionWrapper(var functionClassName: String) - extends java.io.Externalizable { - - // for Serialization - def this() = this(null) - - @transient - def deserializeObjectByKryo[T: ClassTag]( - kryo: Kryo, - in: InputStream, - clazz: Class[_]): T = { - val inp = new Input(in) - val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] - inp.close() - t - } - - @transient - def serializeObjectByKryo( - kryo: Kryo, - plan: Object, - out: OutputStream ) { - val output: Output = new Output(out) - kryo.writeObject(output, plan) - output.close() - } - - def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) - .asInstanceOf[UDFType] - } - - def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) - } - - private var instance: AnyRef = null - - def writeExternal(out: java.io.ObjectOutput) { - // output the function name - out.writeUTF(functionClassName) - - // Write a flag if instance is null or not - out.writeBoolean(instance != null) - if (instance != null) { - // Some of the UDF are serializable, but some others are not - // Hive Utilities can handle both cases - val baos = new java.io.ByteArrayOutputStream() - serializePlan(instance, baos) - val functionInBytes = baos.toByteArray - - // output the function bytes - out.writeInt(functionInBytes.length) - out.write(functionInBytes, 0, functionInBytes.length) - } - } - - def readExternal(in: java.io.ObjectInput) { - // read the function name - functionClassName = in.readUTF() - - if (in.readBoolean()) { - // if the instance is not null - // read the function in bytes - val functionInBytesLength = in.readInt() - val functionInBytes = new Array[Byte](functionInBytesLength) - in.read(functionInBytes, 0, functionInBytesLength) - - // deserialize the function object via Hive Utilities - instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), - getContextOrSparkClassLoader.loadClass(functionClassName)) - } - } - - def createFunction[UDFType <: AnyRef](): UDFType = { - if (instance != null) { - instance.asInstanceOf[UDFType] - } else { - val func = getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - if (!func.isInstanceOf[UDF]) { - // We cache the function if it's no the Simple UDF, - // as we always have to create new instance for Simple UDF - instance = func - } - func - } - } -} - -/** - * A compatibility layer for interacting with Hive version 0.13.1. - */ -private[hive] object HiveShim { - val version = "0.13.1" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(inputFormatClass, outputFormatClass, properties) - } - - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.stringTypeInfo, getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.intTypeInfo, getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.longTypeInfo, getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.floatTypeInfo, getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.shortTypeInfo, getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.byteTypeInfo, getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.dateTypeInfo, getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - TypeInfoFactory.voidTypeInfo, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) { - null - } else { - new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - } - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) { - null - } else { - new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - } - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) { - null - } else { - new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - } - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) { - null - } else { - new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - } - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - // TODO precise, scale? - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[Object] - - def processResults(results: JArrayList[Object]) = { - results.map { r => - r match { - case s: String => s - case a: Array[Object] => a(0).asInstanceOf[String] - } - } - } - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { - context.runSqlHive("CREATE DATABASE default") - context.runSqlHive("USE default") - } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd, conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - HiveDecimal.create(bd) - } - - /* - * This function in hive-0.13 become private, but we have to do this to walkaround hive bug - */ - private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { - val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") - val result: StringBuilder = new StringBuilder(old) - var first: Boolean = old.isEmpty - - for (col <- cols) { - if (first) { - first = false - } else { - result.append(',') - } - result.append(col) - } - conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) - } - - /* - * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null or empty - */ - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - if (ids != null && ids.size > 0) { - ColumnProjectionUtils.appendReadColumns(conf, ids) - } - if (names != null && names.size > 0) { - appendReadColumnNames(conf, names) - } - } - - def getExternalTmpPath(context: Context, path: Path) = { - context.getExternalTmpPath(path.toUri) - } - - def getDataLocationPath(p: Partition) = p.getDataLocation - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsOf(tbl) - - def compatibilityBlackList = Seq() - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation())) - } - - /* - * Bug introdiced in hive-0.13. FileSinkDesc is serializable, but its member path is not. - * Fix it through wrapper. - * */ - implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { - var f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) - f.setCompressCodec(w.compressCodec) - f.setCompressType(w.compressType) - f.setTableInfo(w.tableInfo) - f.setDestTableId(w.destTableId) - f - } - - // Precision and scale to pass for unlimited decimals; these are the same as the precision and - // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) - private val UNLIMITED_DECIMAL_PRECISION = 38 - private val UNLIMITED_DECIMAL_SCALE = 18 - - def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { - case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" - case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" - } - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { - case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) - } - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] - DecimalType(info.precision(), info.scale()) - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, - hdoi.precision(), hdoi.scale()) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) - } - } - - def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) - } - - /* - * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that - * is needed to initialize before serialization. - */ - def prepareWritable(w: Writable): Writable = { - w match { - case w: AvroGenericRecordWritable => - w.setRecordReaderID(new UID()) - case _ => - } - w - } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = { - if (crtTbl != null && crtTbl.getNullFormat() != null) { - tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat()) - } - } -} - -/* - * Bug introduced in hive-0.13. FileSinkDesc is serilizable, but its member path is not. - * Fix it through wrapper. - */ -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends Serializable with Logging { - var compressCodec: String = _ - var compressType: String = _ - var destTableId: Int = _ - - def setCompressed(compressed: Boolean) { - this.compressed = compressed - } - - def getDirName = dir - - def setDestTableId(destTableId: Int) { - this.destTableId = destTableId - } - - def setTableInfo(tableInfo: TableDesc) { - this.tableInfo = tableInfo - } - - def setCompressCodec(intermediateCompressorCodec: String) { - compressCodec = intermediateCompressorCodec - } - - def setCompressType(intermediateCompressType: String) { - compressType = intermediateCompressType - } -} diff --git a/streaming/pom.xml b/streaming/pom.xml index 5ab7f4472c38..697895e72fe5 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css index 19abe889ad3c..ec12616b58d8 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css @@ -31,7 +31,7 @@ } .tooltip-inner { - max-width: 500px !important; // Make sure we only have one line tooltip + max-width: 500px !important; /* Make sure we only have one line tooltip */ } .line { @@ -60,3 +60,7 @@ span.expand-input-rate { cursor: pointer; } + +tr.batch-table-cell-highlight > td { + background-color: #D6FFE4 !important; +} diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index 0ee6752b29e9..4886b68eeaf7 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -31,6 +31,8 @@ var maxXForHistogram = 0; var histogramBinCount = 10; var yValueFormat = d3.format(",.2f"); +var unitLabelYOffset = -10; + // Show a tooltip "text" for "node" function showBootstrapTooltip(node, text) { $(node).tooltip({title: text, trigger: "manual", container: "body"}); @@ -133,7 +135,7 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("class", "y axis") .call(yAxis) .append("text") - .attr("transform", "translate(0," + (-3) + ")") + .attr("transform", "translate(0," + unitLabelYOffset + ")") .text(unitY); @@ -146,6 +148,12 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("class", "line") .attr("d", line); + // If the user click one point in the graphs, jump to the batch row and highlight it. And + // recovery the batch row after 3 seconds if necessary. + // We need to remember the last clicked batch so that we can recovery it. + var lastClickedBatch = null; + var lastTimeout = null; + // Add points to the line. However, we make it invisible at first. But when the user moves mouse // over a point, it will be displayed with its detail. svg.selectAll(".point") @@ -154,6 +162,7 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("stroke", "white") // white and opacity = 0 make it invisible .attr("fill", "white") .attr("opacity", "0") + .style("cursor", "pointer") .attr("cx", function(d) { return x(d.x); }) .attr("cy", function(d) { return y(d.y); }) .attr("r", function(d) { return 3; }) @@ -175,7 +184,29 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("opacity", "0"); }) .on("click", function(d) { - window.location.href = "batch/?id=" + d.x; + if (lastTimeout != null) { + window.clearTimeout(lastTimeout); + } + if (lastClickedBatch != null) { + clearBatchRow(lastClickedBatch); + lastClickedBatch = null; + } + lastClickedBatch = d.x; + highlightBatchRow(lastClickedBatch) + lastTimeout = window.setTimeout(function () { + lastTimeout = null; + if (lastClickedBatch != null) { + clearBatchRow(lastClickedBatch); + lastClickedBatch = null; + } + }, 3000); // Clean up after 3 seconds + + var batchSelector = $("#batch-" + d.x); + var topOffset = batchSelector.offset().top - 15; + if (topOffset < 0) { + topOffset = 0; + } + $('html,body').animate({scrollTop: topOffset}, 200); }); } @@ -194,10 +225,10 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { .style("border-left", "0px solid white"); var margin = {top: 20, right: 30, bottom: 30, left: 10}; - var width = 300 - margin.left - margin.right; + var width = 350 - margin.left - margin.right; var height = 150 - margin.top - margin.bottom; - var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width]); + var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width - 50]); var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); var xAxis = d3.svg.axis().scale(x).orient("top").ticks(5); @@ -218,6 +249,9 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { svg.append("g") .attr("class", "x axis") .call(xAxis) + .append("text") + .attr("transform", "translate(" + (margin.left + width - 45) + ", " + unitLabelYOffset + ")") + .text("#batches"); svg.append("g") .attr("class", "y axis") @@ -279,3 +313,11 @@ $(function() { $(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); } }); + +function highlightBatchRow(batch) { + $("#batch-" + batch).parent().addClass("batch-table-cell-highlight"); +} + +function clearBatchRow(batch) { + $("#batch-" + batch).parent().removeClass("batch-table-cell-highlight"); +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 7bfae253c3a0..5279331c9e12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -44,11 +44,23 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val sparkConfPairs = ssc.conf.getAll def createSparkConf(): SparkConf = { + + // Reload properties for the checkpoint application since user wants to set a reload property + // or spark had changed its value and user wants to set it back. + val propertiesToReload = List( + "spark.master", + "spark.yarn.keytab", + "spark.yarn.principal") + val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.port") - val newMasterOption = new SparkConf(loadDefaults = true).getOption("spark.master") - newMasterOption.foreach { newMaster => newSparkConf.setMaster(newMaster) } + val newReloadConf = new SparkConf(loadDefaults = true) + propertiesToReload.foreach { prop => + newReloadConf.getOption(prop).foreach { value => + newSparkConf.set(prop, value) + } + } newSparkConf } @@ -102,6 +114,44 @@ object Checkpoint extends Logging { Seq.empty } } + + /** Serialize the checkpoint, or throw any exception that occurs */ + def serialize(checkpoint: Checkpoint, conf: SparkConf): Array[Byte] = { + val compressionCodec = CompressionCodec.createCodec(conf) + val bos = new ByteArrayOutputStream() + val zos = compressionCodec.compressedOutputStream(bos) + val oos = new ObjectOutputStream(zos) + Utils.tryWithSafeFinally { + oos.writeObject(checkpoint) + } { + oos.close() + } + bos.toByteArray + } + + /** Deserialize a checkpoint from the input stream, or throw any exception that occurs */ + def deserialize(inputStream: InputStream, conf: SparkConf): Checkpoint = { + val compressionCodec = CompressionCodec.createCodec(conf) + var ois: ObjectInputStreamWithLoader = null + Utils.tryWithSafeFinally { + + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(inputStream) + ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + cp.validate() + cp + } { + if (ois != null) { + ois.close() + } + } + } } @@ -189,17 +239,10 @@ class CheckpointWriter( } def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) { - val bos = new ByteArrayOutputStream() - val zos = compressionCodec.compressedOutputStream(bos) - val oos = new ObjectOutputStream(zos) - Utils.tryWithSafeFinally { - oos.writeObject(checkpoint) - } { - oos.close() - } try { + val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( - checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater)) + checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => @@ -264,25 +307,8 @@ object CheckpointReader extends Logging { checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { - var ois: ObjectInputStreamWithLoader = null - var cp: Checkpoint = null - Utils.tryWithSafeFinally { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - cp = ois.readObject.asInstanceOf[Checkpoint] - } { - if (ois != null) { - ois.close() - } - } - cp.validate() + val fis = fs.open(file) + val cp = Checkpoint.deserialize(fis, conf) logInfo("Checkpoint successfully loaded from file " + file) logInfo("Checkpoint was generated at time " + cp.checkpointTime) return Some(cp) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 85b354ff4aa0..40789c66f399 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -157,10 +157,10 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def validate() { this.synchronized { - assert(batchDuration != null, "Batch duration has not been set") + require(batchDuration != null, "Batch duration has not been set") // assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + // " is very low") - assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") + require(getOutputStreams().size > 0, "No output operations registered, so nothing to execute") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 1d2ecdd34181..ec49d0f42d12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -17,12 +17,13 @@ package org.apache.spark.streaming -import java.io.InputStream +import java.io.{InputStream, NotSerializableException} import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map import scala.collection.mutable.Queue import scala.reflect.ClassTag +import scala.util.control.NonFatal import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration @@ -34,14 +35,15 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.input.FixedLengthBinaryInputFormat -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, RDDOperationScope} +import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.CallSite +import org.apache.spark.util.{CallSite, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -155,7 +157,7 @@ class StreamingContext private[streaming] ( cp_.graph.restoreCheckpointData() cp_.graph } else { - assert(batchDur_ != null, "Batch duration for streaming context cannot be null") + require(batchDur_ != null, "Batch duration for StreamingContext cannot be null") val newGraph = new DStreamGraph() newGraph.setBatchDuration(batchDur_) newGraph @@ -200,6 +202,8 @@ class StreamingContext private[streaming] ( private val startSite = new AtomicReference[CallSite](null) + private var shutdownHookRef: AnyRef = _ + /** * Return the associated Spark context */ @@ -235,21 +239,46 @@ class StreamingContext private[streaming] ( } } + private[streaming] def isCheckpointingEnabled: Boolean = { + checkpointDir != null + } + private[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement() + /** + * Execute a block of code in a scope such that all new DStreams created in this body will + * be part of the same scope. For more detail, see the comments in `doCompute`. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[streaming] def withScope[U](body: => U): U = sparkContext.withScope(body) + + /** + * Execute a block of code in a scope such that all new DStreams created in this body will + * be part of the same scope. For more detail, see the comments in `doCompute`. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[streaming] def withNamedScope[U](name: String)(body: => U): U = { + RDDOperationScope.withScope(sc, name, allowNesting = false, ignoreParent = false)(body) + } + /** * Create an input stream with any arbitrary user implemented receiver. * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver + * + * @deprecated As of 1.0.0", replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") - def networkStream[T: ClassTag]( - receiver: Receiver[T]): ReceiverInputDStream[T] = { - receiverStream(receiver) + def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { + withNamedScope("network stream") { + receiverStream(receiver) + } } /** @@ -257,9 +286,10 @@ class StreamingContext private[streaming] ( * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver */ - def receiverStream[T: ClassTag]( - receiver: Receiver[T]): ReceiverInputDStream[T] = { - new PluggableInputDStream[T](this, receiver) + def receiverStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { + withNamedScope("receiver stream") { + new PluggableInputDStream[T](this, receiver) + } } /** @@ -279,7 +309,7 @@ class StreamingContext private[streaming] ( name: String, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy - ): ReceiverInputDStream[T] = { + ): ReceiverInputDStream[T] = withNamedScope("actor stream") { receiverStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) } @@ -296,7 +326,7 @@ class StreamingContext private[streaming] ( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[String] = { + ): ReceiverInputDStream[String] = withNamedScope("socket text stream") { socketStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) } @@ -334,7 +364,7 @@ class StreamingContext private[streaming] ( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[T] = { + ): ReceiverInputDStream[T] = withNamedScope("raw socket stream") { new RawInputDStream[T](this, hostname, port, storageLevel) } @@ -408,7 +438,7 @@ class StreamingContext private[streaming] ( * file system. File names starting with . are ignored. * @param directory HDFS directory to monitor for new file */ - def textFileStream(directory: String): DStream[String] = { + def textFileStream(directory: String): DStream[String] = withNamedScope("text file stream") { fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } @@ -430,14 +460,15 @@ class StreamingContext private[streaming] ( @Experimental def binaryRecordsStream( directory: String, - recordLength: Int): DStream[Array[Byte]] = { + recordLength: Int): DStream[Array[Byte]] = withNamedScope("binary records stream") { val conf = sc_.hadoopConfiguration conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( - directory, FileInputDStream.defaultFilter : Path => Boolean, newFilesOnly=true, conf) + directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) val data = br.map { case (k, v) => val bytes = v.getBytes - assert(bytes.length == recordLength, "Byte array does not have correct length") + require(bytes.length == recordLength, "Byte array does not have correct length. " + + s"${bytes.length} did not equal recordLength: $recordLength") bytes } data @@ -446,6 +477,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -460,6 +495,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. @@ -477,7 +516,7 @@ class StreamingContext private[streaming] ( /** * Create a unified DStream from multiple DStreams of the same type and same slide duration. */ - def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = { + def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = withScope { new UnionDStream[T](streams.toArray) } @@ -488,7 +527,7 @@ class StreamingContext private[streaming] ( def transform[T: ClassTag]( dstreams: Seq[DStream[_]], transformFunc: (Seq[RDD[_]], Time) => RDD[T] - ): DStream[T] = { + ): DStream[T] = withScope { new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) } @@ -503,11 +542,26 @@ class StreamingContext private[streaming] ( assert(graph != null, "Graph is null") graph.validate() - assert( - checkpointDir == null || checkpointDuration != null, + require( + !isCheckpointingEnabled || checkpointDuration != null, "Checkpoint directory has been set, but the graph checkpointing interval has " + "not been set. Please use StreamingContext.checkpoint() to set the interval." ) + + // Verify whether the DStream checkpoint is serializable + if (isCheckpointingEnabled) { + val checkpoint = new Checkpoint(this, Time.apply(0)) + try { + Checkpoint.serialize(checkpoint, conf) + } catch { + case e: NotSerializableException => + throw new NotSerializableException( + "DStream checkpointing has been enabled but the DStreams with their functions " + + "are not serializable\n" + + SerializationDebugger.improveException(checkpoint, e).getMessage() + ) + } + } } /** @@ -528,26 +582,36 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. * - * @throws SparkException if the StreamingContext is already stopped. + * @throws IllegalStateException if the StreamingContext is already stopped. */ def start(): Unit = synchronized { state match { case INITIALIZED => - validate() startSite.set(DStream.getCreationSite()) sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() - scheduler.start() - uiTab.foreach(_.attach()) - state = StreamingContextState.ACTIVE + try { + validate() + scheduler.start() + state = StreamingContextState.ACTIVE + } catch { + case NonFatal(e) => + logError("Error starting the context, marking it as stopped", e) + scheduler.stop(false) + state = StreamingContextState.STOPPED + throw e + } StreamingContext.setActiveContext(this) } + shutdownHookRef = Utils.addShutdownHook( + StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + uiTab.foreach(_.attach()) logInfo("StreamingContext started") case ACTIVE => logWarning("StreamingContext has already been started") case STOPPED => - throw new SparkException("StreamingContext has already been stopped") + throw new IllegalStateException("StreamingContext has already been stopped") } } @@ -563,6 +627,8 @@ class StreamingContext private[streaming] ( * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { @@ -619,6 +685,9 @@ class StreamingContext private[streaming] ( uiTab.foreach(_.detach()) StreamingContext.setActiveContext(null) waiter.notifyStop() + if (shutdownHookRef != null) { + Utils.removeShutdownHook(shutdownHookRef) + } logInfo("StreamingContext stopped successfully") } // Even if we have already stopped, we still need to attempt to stop the SparkContext because @@ -629,6 +698,13 @@ class StreamingContext private[streaming] ( state = STOPPED } } + + private def stopOnShutdown(): Unit = { + val stopGracefully = conf.getBoolean("spark.streaming.stopGracefullyOnShutdown", false) + logInfo(s"Invoking stop(stopGracefully=$stopGracefully) from shutdown hook") + // Do not stop SparkContext, let its own shutdown hook stop it + stop(stopSparkContext = false, stopGracefully = stopGracefully) + } } /** @@ -644,12 +720,14 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() + private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val activeContext = new AtomicReference[StreamingContext](null) private def assertNoOtherContextIsActive(): Unit = { ACTIVATION_LOCK.synchronized { if (activeContext.get() != null) { - throw new SparkException( + throw new IllegalStateException( "Only one StreamingContext may be started in this JVM. " + "Currently running StreamingContext was started at" + activeContext.get.startSite.get.longForm) @@ -675,6 +753,10 @@ object StreamingContext extends Logging { } } + /** + * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. + * This is kept here only for backward compatibility. + */ @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 93baad19e3ee..959ac9c177f8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -227,7 +227,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param numPartitions Number of partitions of each RDD in the new DStream. */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - :JavaPairDStream[K, JIterable[V]] = { + : JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) .mapValues(asJavaIterable _) } @@ -247,7 +247,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ):JavaPairDStream[K, JIterable[V]] = { + ): JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) .mapValues(asJavaIterable _) } @@ -262,7 +262,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def reduceByKeyAndWindow(reduceFunc: JFunction2[V, V, V], windowDuration: Duration) - :JavaPairDStream[K, V] = { + : JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } @@ -281,7 +281,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration - ):JavaPairDStream[K, V] = { + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index b639b94d5ca4..40deb6d7ea79 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -148,6 +148,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) + /** + * @deprecated As of 0.9.0, replaced by `sparkContext` + */ @deprecated("use sparkContext", "0.9.0") val sc: JavaSparkContext = sparkContext @@ -416,7 +419,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @tparam T Type of objects in the RDD */ @@ -432,7 +439,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -452,7 +463,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty @@ -619,6 +634,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { @@ -677,6 +693,7 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( @@ -699,6 +716,7 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -724,6 +742,7 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 4c28654ef641..d06401245ff1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -109,7 +109,7 @@ private[python] object PythonTransformFunctionSerializer { } def serialize(func: PythonTransformFunction): Array[Byte] = { - assert(serializer != null, "Serializer has not been registered!") + require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") @@ -119,7 +119,7 @@ private[python] object PythonTransformFunctionSerializer { } def deserialize(bytes: Array[Byte]): PythonTransformFunction = { - assert(serializer != null, "Serializer has not been registered!") + require(serializer != null, "Serializer has not been registered!") serializer.loads(bytes) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 64de7526a6a3..192aa6a139bc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -25,12 +25,13 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import scala.util.matching.Regex -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD} +import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD, RDDOperationScope} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming.scheduler.Job +import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.{CallSite, MetadataCleaner, Utils} /** @@ -73,7 +74,7 @@ abstract class DStream[T: ClassTag] ( def dependencies: List[DStream[_]] /** Method that generates a RDD for the given time */ - def compute (validTime: Time): Option[RDD[T]] + def compute(validTime: Time): Option[RDD[T]] // ======================================================================= // Methods and fields available on all DStreams @@ -111,6 +112,44 @@ abstract class DStream[T: ClassTag] ( /* Set the creation call site */ private[streaming] val creationSite = DStream.getCreationSite() + /** + * The base scope associated with the operation that created this DStream. + * + * This is the medium through which we pass the DStream operation name (e.g. updatedStateByKey) + * to the RDDs created by this DStream. Note that we never use this scope directly in RDDs. + * Instead, we instantiate a new scope during each call to `compute` based on this one. + * + * This is not defined if the DStream is created outside of one of the public DStream operations. + */ + protected[streaming] val baseScope: Option[String] = { + Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) + } + + /** + * Make a scope that groups RDDs created in the same DStream operation in the same batch. + * + * Each DStream produces many scopes and each scope may be shared by other DStreams created + * in the same operation. Separate calls to the same DStream operation create separate scopes. + * For instance, `dstream.map(...).map(...)` creates two separate scopes per batch. + */ + private def makeScope(time: Time): Option[RDDOperationScope] = { + baseScope.map { bsJson => + val formattedBatchTime = UIUtils.formatBatchTime( + time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + val bs = RDDOperationScope.fromJson(bsJson) + val baseName = bs.name // e.g. countByWindow, "kafka stream [0]" + val scopeName = + if (baseName.length > 10) { + // If the operation name is too long, wrap the line + s"$baseName\n@ $formattedBatchTime" + } else { + s"$baseName @ $formattedBatchTime" + } + val scopeId = s"${bs.id}_${time.milliseconds}" + new RDDOperationScope(scopeName, id = scopeId) + } + } + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { @@ -178,53 +217,52 @@ abstract class DStream[T: ClassTag] ( case StreamingContextState.INITIALIZED => // good to go case StreamingContextState.ACTIVE => - throw new SparkException( + throw new IllegalStateException( "Adding new inputs, transformations, and output operations after " + "starting a context is not supported") case StreamingContextState.STOPPED => - throw new SparkException( + throw new IllegalStateException( "Adding new inputs, transformations, and output operations after " + "stopping a context is not supported") } } private[streaming] def validateAtStart() { - assert(rememberDuration != null, "Remember duration is set to null") + require(rememberDuration != null, "Remember duration is set to null") - assert( + require( !mustCheckpoint || checkpointDuration != null, "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + " Please use DStream.checkpoint() to set the interval." ) - assert( + require( checkpointDuration == null || context.sparkContext.checkpointDir.isDefined, - "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + - " or SparkContext.checkpoint() to set the checkpoint directory." + "The checkpoint directory has not been set. Please set it by StreamingContext.checkpoint()." ) - assert( + require( checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + "Please set it to at least " + slideDuration + "." ) - assert( + require( checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + - "Please set it to a multiple " + slideDuration + "." + "Please set it to a multiple of " + slideDuration + "." ) - assert( + require( checkpointDuration == null || storageLevel != StorageLevel.NONE, "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + "level has not been set to enable persisting. Please use DStream.persist() to set the " + "storage level to use memory for better checkpointing performance." ) - assert( + require( checkpointDuration == null || rememberDuration > checkpointDuration, "The remember duration for " + this.getClass.getSimpleName + " has been set to " + rememberDuration + " which is not more than the checkpoint interval (" + @@ -233,7 +271,7 @@ abstract class DStream[T: ClassTag] ( val metadataCleanerDelay = MetadataCleaner.getDelaySeconds(ssc.conf) logInfo("metadataCleanupDelay = " + metadataCleanerDelay) - assert( + require( metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + @@ -295,28 +333,23 @@ abstract class DStream[T: ClassTag] ( * Get the RDD corresponding to the given time; either retrieve it from cache * or compute-and-cache it. */ - private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { + private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = { // If RDD was already generated, then retrieve it from HashMap, // or else compute the RDD generatedRDDs.get(time).orElse { // Compute the RDD if time is valid (e.g. correct time in a sliding window) // of RDD generation, else generate nothing. if (isTimeValid(time)) { - // Set the thread-local property for call sites to this DStream's creation site - // such that RDDs generated by compute gets that as their creation site. - // Note that this `getOrCompute` may get called from another DStream which may have - // set its own call site. So we store its call site in a temporary variable, - // set this DStream's creation site, generate RDDs and then restore the previous call site. - val prevCallSite = ssc.sparkContext.getCallSite() - ssc.sparkContext.setCallSite(creationSite) - // Disable checks for existing output directories in jobs launched by the streaming - // scheduler, since we may need to write output to an existing directory during checkpoint - // recovery; see SPARK-4835 for more details. We need to have this call here because - // compute() might cause Spark jobs to be launched. - val rddOption = PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - compute(time) + + val rddOption = createRDDWithLocalProperties(time) { + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. We need to have this call here because + // compute() might cause Spark jobs to be launched. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + compute(time) + } } - ssc.sparkContext.setCallSite(prevCallSite) rddOption.foreach { case newRDD => // Register the generated RDD for caching and checkpointing @@ -337,6 +370,41 @@ abstract class DStream[T: ClassTag] ( } } + /** + * Wrap a body of code such that the call site and operation scope + * information are passed to the RDDs created in this body properly. + */ + protected def createRDDWithLocalProperties[U](time: Time)(body: => U): U = { + val scopeKey = SparkContext.RDD_SCOPE_KEY + val scopeNoOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY + // Pass this DStream's operation scope and creation site information to RDDs through + // thread-local properties in our SparkContext. Since this method may be called from another + // DStream, we need to temporarily store any old scope and creation site information to + // restore them later after setting our own. + val prevCallSite = ssc.sparkContext.getCallSite() + val prevScope = ssc.sparkContext.getLocalProperty(scopeKey) + val prevScopeNoOverride = ssc.sparkContext.getLocalProperty(scopeNoOverrideKey) + + try { + ssc.sparkContext.setCallSite(creationSite) + // Use the DStream's base scope for this RDD so we can (1) preserve the higher level + // DStream operation name, and (2) share this scope with other DStreams created in the + // same operation. Disallow nesting so that low-level Spark primitives do not show up. + // TODO: merge callsites with scopes so we can just reuse the code there + makeScope(time).foreach { s => + ssc.sparkContext.setLocalProperty(scopeKey, s.toJson) + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + } + + body + } finally { + // Restore any state that was modified before returning + ssc.sparkContext.setCallSite(prevCallSite) + ssc.sparkContext.setLocalProperty(scopeKey, prevScope) + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, prevScopeNoOverride) + } + } + /** * Generate a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job @@ -456,7 +524,7 @@ abstract class DStream[T: ClassTag] ( // ======================================================================= /** Return a new DStream by applying a function to all elements of this DStream. */ - def map[U: ClassTag](mapFunc: T => U): DStream[U] = { + def map[U: ClassTag](mapFunc: T => U): DStream[U] = ssc.withScope { new MappedDStream(this, context.sparkContext.clean(mapFunc)) } @@ -464,26 +532,31 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = { + def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = ssc.withScope { new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } /** Return a new DStream containing only the elements that satisfy a predicate. */ - def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + def filter(filterFunc: T => Boolean): DStream[T] = ssc.withScope { + new FilteredDStream(this, context.sparkContext.clean(filterFunc)) + } /** * Return a new DStream in which each RDD is generated by applying glom() to each RDD of * this DStream. Applying glom() to an RDD coalesces all elements within each partition into * an array. */ - def glom(): DStream[Array[T]] = new GlommedDStream(this) - + def glom(): DStream[Array[T]] = ssc.withScope { + new GlommedDStream(this) + } /** * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the * returned DStream has exactly numPartitions partitions. */ - def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions)) + def repartition(numPartitions: Int): DStream[T] = ssc.withScope { + this.transform(_.repartition(numPartitions)) + } /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs @@ -493,7 +566,7 @@ abstract class DStream[T: ClassTag] ( def mapPartitions[U: ClassTag]( mapPartFunc: Iterator[T] => Iterator[U], preservePartitioning: Boolean = false - ): DStream[U] = { + ): DStream[U] = ssc.withScope { new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning) } @@ -501,14 +574,15 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD has a single element generated by reducing each RDD * of this DStream. */ - def reduce(reduceFunc: (T, T) => T): DStream[T] = + def reduce(reduceFunc: (T, T) => T): DStream[T] = ssc.withScope { this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + } /** * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = { + def count(): DStream[Long] = ssc.withScope { this.map(_ => (null, 1L)) .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) .reduceByKey(_ + _) @@ -522,24 +596,29 @@ abstract class DStream[T: ClassTag] ( * `numPartitions` not specified). */ def countByValue(numPartitions: Int = ssc.sc.defaultParallelism)(implicit ord: Ordering[T] = null) - : DStream[(T, Long)] = + : DStream[(T, Long)] = ssc.withScope { this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) + } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: RDD[T] => Unit): Unit = { + def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { this.foreachRDD(foreachFunc) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = { + def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { this.foreachRDD(foreachFunc) } @@ -547,17 +626,18 @@ abstract class DStream[T: ClassTag] ( * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: RDD[T] => Unit) { - this.foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + def foreachRDD(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { + val cleanedF = context.sparkContext.clean(foreachFunc, false) + this.foreachRDD((r: RDD[T], t: Time) => cleanedF(r)) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def foreachRDD(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register() } @@ -566,9 +646,9 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = context.sparkContext.clean(transformFunc, false) transform((r: RDD[T], t: Time) => cleanedF(r)) @@ -578,12 +658,12 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = context.sparkContext.clean(transformFunc, false) - val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { + val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 1) cleanedF(rdds.head.asInstanceOf[RDD[T]], time) } @@ -596,9 +676,9 @@ abstract class DStream[T: ClassTag] ( */ def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V] - ): DStream[V] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + ): DStream[V] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = ssc.sparkContext.clean(transformFunc, false) transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2)) @@ -610,9 +690,9 @@ abstract class DStream[T: ClassTag] ( */ def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V] - ): DStream[V] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + ): DStream[V] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = ssc.sparkContext.clean(transformFunc, false) val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { @@ -628,7 +708,7 @@ abstract class DStream[T: ClassTag] ( * Print the first ten elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ - def print() { + def print(): Unit = ssc.withScope { print(10) } @@ -636,7 +716,7 @@ abstract class DStream[T: ClassTag] ( * Print the first num elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ - def print(num: Int) { + def print(num: Int): Unit = ssc.withScope { def foreachFunc: (RDD[T], Time) => Unit = { (rdd: RDD[T], time: Time) => { val firstNum = rdd.take(num + 1) @@ -668,7 +748,7 @@ abstract class DStream[T: ClassTag] ( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval */ - def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { + def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = ssc.withScope { new WindowedDStream(this, windowDuration, slideDuration) } @@ -686,7 +766,7 @@ abstract class DStream[T: ClassTag] ( reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration - ): DStream[T] = { + ): DStream[T] = ssc.withScope { this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) } @@ -711,7 +791,7 @@ abstract class DStream[T: ClassTag] ( invReduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration - ): DStream[T] = { + ): DStream[T] = ssc.withScope { this.map(x => (1, x)) .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) .map(_._2) @@ -727,7 +807,9 @@ abstract class DStream[T: ClassTag] ( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval */ - def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { + def countByWindow( + windowDuration: Duration, + slideDuration: Duration): DStream[Long] = ssc.withScope { this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } @@ -748,8 +830,7 @@ abstract class DStream[T: ClassTag] ( slideDuration: Duration, numPartitions: Int = ssc.sc.defaultParallelism) (implicit ord: Ordering[T] = null) - : DStream[(T, Long)] = - { + : DStream[(T, Long)] = ssc.withScope { this.map(x => (x, 1L)).reduceByKeyAndWindow( (x: Long, y: Long) => x + y, (x: Long, y: Long) => x - y, @@ -764,19 +845,21 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same slideDuration as this DStream. */ - def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) + def union(that: DStream[T]): DStream[T] = ssc.withScope { + new UnionDStream[T](Array(this, that)) + } /** * Return all the RDDs defined by the Interval object (both end times included) */ - def slice(interval: Interval): Seq[RDD[T]] = { + def slice(interval: Interval): Seq[RDD[T]] = ssc.withScope { slice(interval.beginTime, interval.endTime) } /** * Return all the RDDs between 'fromTime' to 'toTime' (both included) */ - def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { + def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = ssc.withScope { if (!isInitialized) { throw new SparkException(this + " has not been initialized") } @@ -810,7 +893,7 @@ abstract class DStream[T: ClassTag] ( * The file name at each batch interval is generated based on `prefix` and * `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsObjectFiles(prefix: String, suffix: String = "") { + def saveAsObjectFiles(prefix: String, suffix: String = ""): Unit = ssc.withScope { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) @@ -823,7 +906,7 @@ abstract class DStream[T: ClassTag] ( * of elements. The file name at each batch interval is generated based on * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsTextFiles(prefix: String, suffix: String = "") { + def saveAsTextFiles(prefix: String, suffix: String = ""): Unit = ssc.withScope { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index eca69f00188e..86a8e2beff57 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -26,10 +26,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.{SparkConf, SerializableWritable} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ -import org.apache.spark.util.{TimeStampedHashMap, Utils} +import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** * This class represents an input stream that monitors a Hadoop-compatible filesystem for new @@ -69,7 +68,7 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * processing semantics are undefined. */ private[streaming] -class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( +class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( @transient ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, @@ -78,7 +77,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) extends InputDStream[(K, V)](ssc_) { - private val serializableConfOpt = conf.map(new SerializableWritable(_)) + private val serializableConfOpt = conf.map(new SerializableConfiguration(_)) /** * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. @@ -251,7 +250,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file =>{ + val fileRDDs = files.map { file => val rdd = serializableConfOpt.map(_.value) match { case Some(config) => context.sparkContext.newAPIHadoopFile( file, @@ -267,7 +266,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( "Refer to the streaming programming guide for more details.") } rdd - }) + } new UnionRDD(context.sparkContext, fileRDDs) } @@ -294,7 +293,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () + generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 685a32e1d280..c109ceccc698 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -37,7 +37,7 @@ class ForEachDStream[T: ClassTag] ( override def generateJob(time: Time): Option[Job] = { parent.getOrCompute(time) match { case Some(rdd) => - val jobFunc = () => { + val jobFunc = () => createRDDWithLocalProperties(time) { ssc.sparkContext.setCallSite(creationSite) foreachFunc(rdd, time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 9716adb62817..d58c99a8ff32 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,10 +17,13 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Time, Duration, StreamingContext} - import scala.reflect.ClassTag +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.util.Utils + /** * This is the abstract base class for all input streams. This class provides methods * start() and stop() which is called by Spark Streaming system to start and stop receiving data. @@ -44,10 +47,31 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + /** A human-readable name of this InputDStream */ + private[streaming] def name: String = { + // e.g. FlumePollingDStream -> "Flume polling stream" + val newName = Utils.getFormattedClassName(this) + .replaceAll("InputDStream", "Stream") + .split("(?=[A-Z])") + .filter(_.nonEmpty) + .mkString(" ") + .toLowerCase + .capitalize + s"$newName [$id]" + } + /** - * The name of this InputDStream. By default, it's the class name with its id. + * The base scope associated with the operation that created this DStream. + * + * For InputDStreams, we use the name of this DStream as the scope name. + * If an outer scope is given, we assume that it includes an alternative name for this stream. */ - private[streaming] def name: String = s"${getClass.getSimpleName}-$id" + protected[streaming] override val baseScope: Option[String] = { + val scopeName = Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) + .map { json => RDDOperationScope.fromJson(json).name + s" [$id]" } + .getOrElse(name.toLowerCase) + Some(new RDDOperationScope(scopeName).toJson) + } /** * Checks whether the 'time' is valid wrt slideDuration for generating RDD. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 8a5857163244..71bec96d46c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,20 +24,23 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable} +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ -class PairDStreamFunctions[K, V](self: DStream[(K,V)]) +class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) extends Serializable { private[streaming] def ssc = self.ssc + private[streaming] def sparkContext = self.context.sparkContext + private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { new HashPartitioner(numPartitions) } @@ -46,7 +49,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ - def groupByKey(): DStream[(K, Iterable[V])] = { + def groupByKey(): DStream[(K, Iterable[V])] = ssc.withScope { groupByKey(defaultPartitioner()) } @@ -54,7 +57,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ - def groupByKey(numPartitions: Int): DStream[(K, Iterable[V])] = { + def groupByKey(numPartitions: Int): DStream[(K, Iterable[V])] = ssc.withScope { groupByKey(defaultPartitioner(numPartitions)) } @@ -62,7 +65,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` on each RDD. The supplied * org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ - def groupByKey(partitioner: Partitioner): DStream[(K, Iterable[V])] = { + def groupByKey(partitioner: Partitioner): DStream[(K, Iterable[V])] = ssc.withScope { val createCombiner = (v: V) => ArrayBuffer[V](v) val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) @@ -75,7 +78,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the associative reduce function. Hash partitioning is used to generate the RDDs * with Spark's default number of partitions. */ - def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { + def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner()) } @@ -84,7 +87,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs * with `numPartitions` partitions. */ - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { + def reduceByKey( + reduceFunc: (V, V) => V, + numPartitions: Int): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } @@ -93,9 +98,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the supplied reduce function. org.apache.spark.Partitioner is used to control * the partitioning of each RDD. */ - def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) + def reduceByKey( + reduceFunc: (V, V) => V, + partitioner: Partitioner): DStream[(K, V)] = ssc.withScope { + combineByKey((v: V) => v, reduceFunc, reduceFunc, partitioner) } /** @@ -104,12 +110,20 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * org.apache.spark.rdd.PairRDDFunctions in the Spark core documentation for more information. */ def combineByKey[C: ClassTag]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner, - mapSideCombine: Boolean = true): DStream[(K, C)] = { - new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true): DStream[(K, C)] = ssc.withScope { + val cleanedCreateCombiner = sparkContext.clean(createCombiner) + val cleanedMergeValue = sparkContext.clean(mergeValue) + val cleanedMergeCombiner = sparkContext.clean(mergeCombiner) + new ShuffledDStream[K, V, C]( + self, + cleanedCreateCombiner, + cleanedMergeValue, + cleanedMergeCombiner, + partitioner, mapSideCombine) } @@ -121,7 +135,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ - def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Iterable[V])] = { + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) } @@ -136,8 +150,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * DStream's batching interval */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) - : DStream[(K, Iterable[V])] = - { + : DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } @@ -157,7 +170,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, numPartitions: Int - ): DStream[(K, Iterable[V])] = { + ): DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } @@ -176,7 +189,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ): DStream[(K, Iterable[V])] = { + ): DStream[(K, Iterable[V])] = ssc.withScope { val createCombiner = (v: Iterable[V]) => new ArrayBuffer[V] ++= v val mergeValue = (buf: ArrayBuffer[V], v: Iterable[V]) => buf ++= v val mergeCombiner = (buf1: ArrayBuffer[V], buf2: ArrayBuffer[V]) => buf1 ++= buf2 @@ -198,7 +211,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) } @@ -217,7 +230,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) reduceFunc: (V, V) => V, windowDuration: Duration, slideDuration: Duration - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) } @@ -238,7 +251,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, numPartitions: Int - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } @@ -260,11 +273,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - self.reduceByKey(cleanedReduceFunc, partitioner) + ): DStream[(K, V)] = ssc.withScope { + self.reduceByKey(reduceFunc, partitioner) .window(windowDuration, slideDuration) - .reduceByKey(cleanedReduceFunc, partitioner) + .reduceByKey(reduceFunc, partitioner) } /** @@ -294,8 +306,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) slideDuration: Duration = self.slideDuration, numPartitions: Int = ssc.sc.defaultParallelism, filterFunc: ((K, V)) => Boolean = null - ): DStream[(K, V)] = { - + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions), filterFunc @@ -328,7 +339,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) slideDuration: Duration, partitioner: Partitioner, filterFunc: ((K, V)) => Boolean - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) @@ -349,7 +360,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) */ def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { updateStateByKey(updateFunc, defaultPartitioner()) } @@ -365,7 +376,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } @@ -382,9 +393,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { + val cleanedUpdateF = sparkContext.clean(updateFunc) val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s))) } updateStateByKey(newUpdateFunc, partitioner, true) } @@ -406,7 +418,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None) } @@ -425,9 +437,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner, initialRDD: RDD[(K, S)] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { + val cleanedUpdateF = sparkContext.clean(updateFunc) val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s))) } updateStateByKey(newUpdateFunc, partitioner, true, initialRDD) } @@ -451,7 +464,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) partitioner: Partitioner, rememberPartitioner: Boolean, initialRDD: RDD[(K, S)] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, Some(initialRDD)) } @@ -460,8 +473,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying a map function to the value of each key-value pairs in * 'this' DStream without changing the key. */ - def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = { - new MapValuedDStream[K, V, U](self, mapValuesFunc) + def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = ssc.withScope { + new MapValuedDStream[K, V, U](self, sparkContext.clean(mapValuesFunc)) } /** @@ -470,8 +483,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) */ def flatMapValues[U: ClassTag]( flatMapValuesFunc: V => TraversableOnce[U] - ): DStream[(K, U)] = { - new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) + ): DStream[(K, U)] = ssc.withScope { + new FlatMapValuedDStream[K, V, U](self, sparkContext.clean(flatMapValuesFunc)) } /** @@ -479,7 +492,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Hash partitioning is used to generate the RDDs with Spark's default number * of partitions. */ - def cogroup[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { cogroup(other, defaultPartitioner()) } @@ -487,8 +501,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ - def cogroup[W: ClassTag](other: DStream[(K, W)], numPartitions: Int) - : DStream[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { cogroup(other, defaultPartitioner(numPartitions)) } @@ -499,7 +514,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def cogroup[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Iterable[V], Iterable[W]))] = { + ): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.cogroup(rdd2, partitioner) @@ -510,7 +525,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. */ - def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = { + def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = ssc.withScope { join[W](other, defaultPartitioner()) } @@ -518,7 +533,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ - def join[W: ClassTag](other: DStream[(K, W)], numPartitions: Int): DStream[(K, (V, W))] = { + def join[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int): DStream[(K, (V, W))] = ssc.withScope { join[W](other, defaultPartitioner(numPartitions)) } @@ -529,7 +546,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def join[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (V, W))] = { + ): DStream[(K, (V, W))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.join(rdd2, partitioner) @@ -541,7 +558,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def leftOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = { + def leftOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = ssc.withScope { leftOuterJoin[W](other, defaultPartitioner()) } @@ -553,7 +571,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def leftOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (V, Option[W]))] = { + ): DStream[(K, (V, Option[W]))] = ssc.withScope { leftOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -565,7 +583,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def leftOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (V, Option[W]))] = { + ): DStream[(K, (V, Option[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.leftOuterJoin(rdd2, partitioner) @@ -577,7 +595,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def rightOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = { + def rightOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = ssc.withScope { rightOuterJoin[W](other, defaultPartitioner()) } @@ -589,7 +608,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def rightOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (Option[V], W))] = { + ): DStream[(K, (Option[V], W))] = ssc.withScope { rightOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -601,7 +620,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def rightOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Option[V], W))] = { + ): DStream[(K, (Option[V], W))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.rightOuterJoin(rdd2, partitioner) @@ -613,7 +632,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def fullOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = { + def fullOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { fullOuterJoin[W](other, defaultPartitioner()) } @@ -625,7 +645,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def fullOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (Option[V], Option[W]))] = { + ): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { fullOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -637,7 +657,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def fullOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Option[V], Option[W]))] = { + ): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.fullOuterJoin(rdd2, partitioner) @@ -651,7 +671,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def saveAsHadoopFiles[F <: OutputFormat[K, V]]( prefix: String, suffix: String - )(implicit fm: ClassTag[F]) { + )(implicit fm: ClassTag[F]): Unit = ssc.withScope { saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -667,9 +687,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration) - ) { + ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) @@ -684,7 +704,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( prefix: String, suffix: String - )(implicit fm: ClassTag[F]) { + )(implicit fm: ClassTag[F]): Unit = ssc.withScope { saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -700,9 +720,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], conf: Configuration = ssc.sparkContext.hadoopConfiguration - ) { + ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableConfiguration(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index ed7da6dc1315..a2f5d82a79bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.streaming.{Time, StreamingContext} +import java.io.{NotSerializableException, ObjectOutputStream} + +import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.streaming.{Time, StreamingContext} + private[streaming] class QueueInputDStream[T: ClassTag]( @transient ssc: StreamingContext, @@ -36,6 +37,10 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def writeObject(oos: ObjectOutputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing") + } + override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() if (oneAtATime && queue.size > 0) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 5cfe43a1ce72..e76e7eb0dea1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -70,30 +70,41 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray // Register the input blocks information into InputInfoTracker - val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum) + val inputInfo = InputInfo(id, blockInfos.flatMap(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) - // Are WAL record handles present with all the blocks - val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + if (blockInfos.nonEmpty) { + // Are WAL record handles present with all the blocks + val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } - if (areWALRecordHandlesPresent) { - // If all the blocks have WAL record handle, then create a WALBackedBlockRDD - val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray - val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray - new WriteAheadLogBackedBlockRDD[T]( - ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) - } else { - // Else, create a BlockRDD. However, if there are some blocks with WAL info but not others - // then that is unexpected and log a warning accordingly. - if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { - if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { - logError("Some blocks do not have Write Ahead Log information; " + - "this is unexpected and data may not be recoverable after driver failures") - } else { - logWarning("Some blocks have Write Ahead Log information; this is unexpected") + if (areWALRecordHandlesPresent) { + // If all the blocks have WAL record handle, then create a WALBackedBlockRDD + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) + } else { + // Else, create a BlockRDD. However, if there are some blocks with WAL info but not + // others then that is unexpected and log a warning accordingly. + if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + logError("Some blocks do not have Write Ahead Log information; " + + "this is unexpected and data may not be recoverable after driver failures") + } else { + logWarning("Some blocks have Write Ahead Log information; this is unexpected") + } } + new BlockRDD[T](ssc.sc, blockIds) + } + } else { + // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD + // according to the configuration + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, Array.empty, Array.empty, Array.empty) + } else { + new BlockRDD[T](ssc.sc, Array.empty) } - new BlockRDD[T](ssc.sc, blockIds) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 1385ccbf56ee..6a583bf2a362 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -38,14 +38,14 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( _windowDuration: Duration, _slideDuration: Duration, partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { + ) extends DStream[(K, V)](parent.ssc) { - assert(_windowDuration.isMultipleOf(parent.slideDuration), + require(_windowDuration.isMultipleOf(parent.slideDuration), "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) - assert(_slideDuration.isMultipleOf(parent.slideDuration), + require(_slideDuration.isMultipleOf(parent.slideDuration), "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) @@ -58,7 +58,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(reducedStream) @@ -68,7 +68,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( override def parentRememberDuration: Duration = rememberDuration + windowDuration - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + override def persist(storageLevel: StorageLevel): DStream[(K, V)] = { super.persist(storageLevel) reducedStream.persist(storageLevel) this @@ -118,7 +118,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Get the RDD of the reduced value of the previous window val previousWindowRDD = - getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K, V)]())) // Make the list of RDDs that needs to cogrouped together for reducing their reduced values val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 7757ccac09a5..e0ffd5d86b43 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -25,19 +25,19 @@ import scala.reflect.ClassTag private[streaming] class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( - parent: DStream[(K,V)], + parent: DStream[(K, V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true - ) extends DStream[(K,C)] (parent.ssc) { + ) extends DStream[(K, C)] (parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - override def compute(validTime: Time): Option[RDD[(K,C)]] = { + override def compute(validTime: Time): Option[RDD[(K, C)]] = { parent.getOrCompute(validTime) match { case Some(rdd) => Some(rdd.combineByKey[C]( createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 8b72bcf20653..5ce5b7aae6e6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import scala.util.control.NonFatal + import org.apache.spark.streaming.StreamingContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.NextIterator @@ -74,13 +76,17 @@ class SocketReceiver[T: ClassTag]( while(!isStopped && iterator.hasNext) { store(iterator.next) } - logInfo("Stopped receiving") - restart("Retrying connecting to " + host + ":" + port) + if (!isStopped()) { + restart("Socket data stream had no more data") + } else { + logInfo("Stopped receiving") + } } catch { case e: java.net.ConnectException => restart("Error connecting to " + host + ":" + port, e) - case t: Throwable => - restart("Error receiving data", t) + case NonFatal(e) => + logWarning("Error receiving data", e) + restart("Error receiving data", e) } finally { if (socket != null) { socket.close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index de8718d0a80f..621d6dff788f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -51,7 +51,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { val i = iterator.map(t => { val itr = t._2._2.iterator - val headOption = if(itr.hasNext) Some(itr.next) else None + val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) }) updateFuncLocal(i) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 899865a906c2..4efba039f895 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -44,7 +44,7 @@ class WindowedDStream[T: ClassTag]( // Persist parent level by default, as those RDDs are going to be obviously reused. parent.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(parent) @@ -68,7 +68,7 @@ class WindowedDStream[T: ClassTag]( new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow) } else { logDebug("Using normal union for windowing at " + validTime) - new UnionRDD(ssc.sc,rddsInWindow) + new UnionRDD(ssc.sc, rddsInWindow) } Some(windowRDD) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index ffce6a4c3c74..31ce8e1ec14d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -23,12 +23,11 @@ import java.util.UUID import scala.reflect.ClassTag import scala.util.control.NonFatal -import org.apache.commons.io.FileUtils - import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ +import org.apache.spark.util.SerializableConfiguration /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. @@ -94,7 +93,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration - private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) + private val broadcastedHadoopConf = new SerializableConfiguration(hadoopConfig) override def isValid(): Boolean = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 4bebcc5aa7ca..92b51ce39234 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{SystemClock, Utils} +import org.apache.spark.util.SystemClock /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -80,6 +80,8 @@ private[streaming] class BlockGenerator( private val clock = new SystemClock() private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") + require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value") + private val blockIntervalTimer = new RecurringTimer(clock, blockIntervalMs, updateCurrentBuffer, "BlockGenerator") private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10) @@ -164,7 +166,7 @@ private[streaming] class BlockGenerator( private def keepPushingBlocks() { logInfo("Started block pushing thread") try { - while(!stopped) { + while (!stopped) { Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => @@ -191,7 +193,7 @@ private[streaming] class BlockGenerator( logError(message, t) listener.onError(message, t) } - + private def pushBlock(block: Block) { listener.onPushBlock(block.id, block.buffer) logInfo("Pushed block " + block.id) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 97db9ded8336..8df542b367d2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.receiver +import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} + import org.apache.spark.{Logging, SparkConf} -import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. * diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 651b534ac190..c8dd6e06812d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -32,7 +32,10 @@ import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { - def blockId: StreamBlockId // Any implementation of this trait will store a block id + // Any implementation of this trait will store a block id + def blockId: StreamBlockId + // Any implementation of this trait will have to return the number of records + def numRecords: Option[Long] } /** Trait that represents a class that handles the storage of blocks received by receiver */ @@ -51,7 +54,8 @@ private[streaming] trait ReceivedBlockHandler { * that stores the metadata related to storage of blocks using * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] */ -private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) +private[streaming] case class BlockManagerBasedStoreResult( + blockId: StreamBlockId, numRecords: Option[Long]) extends ReceivedBlockStoreResult @@ -62,13 +66,22 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI private[streaming] class BlockManagerBasedBlockHandler( blockManager: BlockManager, storageLevel: StorageLevel) extends ReceivedBlockHandler with Logging { - + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + + var numRecords = None: Option[Long] + val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => - blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + numRecords = Some(arrayBuffer.size.toLong) + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, + tellMaster = true) case IteratorBlock(iterator) => - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + val countIterator = new CountingIterator(iterator) + val putResult = blockManager.putIterator(blockId, countIterator, storageLevel, + tellMaster = true) + numRecords = countIterator.count + putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) case o => @@ -79,7 +92,7 @@ private[streaming] class BlockManagerBasedBlockHandler( throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } - BlockManagerBasedStoreResult(blockId) + BlockManagerBasedStoreResult(blockId, numRecords) } def cleanupOldBlocks(threshTime: Long) { @@ -96,6 +109,7 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, + numRecords: Option[Long], walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -151,12 +165,17 @@ private[streaming] class WriteAheadLogBasedBlockHandler( */ def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + var numRecords = None: Option[Long] // Serialize the block so that it can be inserted into both val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => + numRecords = Some(arrayBuffer.size.toLong) blockManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => - blockManager.dataSerialize(blockId, iterator) + val countIterator = new CountingIterator(iterator) + val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + numRecords = countIterator.count + serializedBlock case ByteBufferBlock(byteBuffer) => byteBuffer case _ => @@ -181,7 +200,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, walRecordHandle) + WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { @@ -199,3 +218,23 @@ private[streaming] object WriteAheadLogBasedBlockHandler { new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString } } + +/** + * A utility that will wrap the Iterator to get the count + */ +private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { + private var _count = 0 + + private def isFullyConsumed: Boolean = !iterator.hasNext + + def hasNext(): Boolean = iterator.hasNext + + def count(): Option[Long] = { + if (isFullyConsumed) Some(_count) else None + } + + def next(): T = { + _count += 1 + iterator.next() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 4943f29395d1..33be067ebdaf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -18,14 +18,14 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer +import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import java.util.concurrent.CountDownLatch -import scala.concurrent._ -import ExecutionContext.Implicits.global +import org.apache.spark.util.ThreadUtils /** * Abstract class that is responsible for supervising a Receiver in the worker. @@ -46,6 +46,9 @@ private[streaming] abstract class ReceiverSupervisor( // Attach the executor to the receiver receiver.attachExecutor(this) + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) + /** Receiver id */ protected val streamId = receiver.streamId @@ -111,6 +114,7 @@ private[streaming] abstract class ReceiverSupervisor( stoppingError = error.orNull stopReceiver(message, error) onStop(message, error) + futureExecutionContext.shutdownNow() stopLatch.countDown() } @@ -150,6 +154,8 @@ private[streaming] abstract class ReceiverSupervisor( /** Restart receiver with delay */ def restartReceiver(message: String, error: Option[Throwable], delay: Int) { Future { + // This is a blocking action so we should use "futureExecutionContext" which is a cached + // thread pool. logWarning("Restarting receiver with delay " + delay + " ms: " + message, error.getOrElse(null)) stopReceiver("Restarting receiver with delay " + delay + "ms: " + message, error) @@ -158,7 +164,7 @@ private[streaming] abstract class ReceiverSupervisor( logInfo("Starting receiver again") startReceiver() logInfo("Receiver started again") - } + }(futureExecutionContext) } /** Check if receiver has been marked for stopping */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 92938379b9c1..6078cdf8f879 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -137,15 +137,10 @@ private[streaming] class ReceiverSupervisorImpl( blockIdOption: Option[StreamBlockId] ) { val blockId = blockIdOption.getOrElse(nextBlockId) - val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size - case _ => -1 - } - val time = System.currentTimeMillis val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") - + val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index a72efccf2f99..7c0db8a863c6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -23,7 +23,9 @@ import org.apache.spark.Logging import org.apache.spark.streaming.{Time, StreamingContext} /** To track the information of input stream at specified batch time. */ -private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) +private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) { + require(numRecords >= 0, "numRecords must not be negative") +} /** * This class manages all the input streams as well as their input data statistics. The information diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1d1ddaaccf21..4af9b6d3b56a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -126,6 +126,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { eventLoop.post(ErrorReported(msg, e)) } + def isStarted(): Boolean = synchronized { + eventLoop != null + } + private def processEvent(event: JobSchedulerEvent) { try { event match { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala index dc11e84f2996..656ac80df897 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala @@ -24,11 +24,13 @@ import org.apache.spark.streaming.util.WriteAheadLogRecordHandle /** Information about blocks received by the receiver */ private[streaming] case class ReceivedBlockInfo( streamId: Int, - numRecords: Long, + numRecords: Option[Long], metadataOption: Option[Any], blockStoreResult: ReceivedBlockStoreResult ) { + require(numRecords.isEmpty || numRecords.get >= 0, "numRecords must not be negative") + @volatile private var _isBlockIdValid = true def blockId: StreamBlockId = blockStoreResult.blockId diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index a9f4147a5f02..7720259a5d79 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -153,7 +153,7 @@ private[streaming] class ReceivedBlockTracker( * returns only after the files are cleaned up. */ def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { - assert(cleanupThreshTime.milliseconds < clock.getTimeMillis()) + require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f73f7e705ee0..644e581cd827 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,14 +17,18 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap} +import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} import scala.language.existentials +import scala.math.max +import org.apache.spark.rdd._ import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, + StopReceiver} +import org.apache.spark.util.SerializableConfiguration /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -230,7 +234,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false class ReceiverLauncher { @transient val env = ssc.env @volatile @transient private var running = false - @transient val thread = new Thread() { + @transient val thread = new Thread() { override def run() { try { SparkEnv.set(env) @@ -270,6 +274,41 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } + /** + * Get the list of executors excluding driver + */ + private def getExecutors(ssc: StreamingContext): List[String] = { + val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList + val driver = ssc.sparkContext.getConf.get("spark.driver.host") + executors.diff(List(driver)) + } + + /** Set host location(s) for each receiver so as to distribute them over + * executors in a round-robin fashion taking into account preferredLocation if set + */ + private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]], + executors: List[String]): Array[ArrayBuffer[String]] = { + val locations = new Array[ArrayBuffer[String]](receivers.length) + var i = 0 + for (i <- 0 until receivers.length) { + locations(i) = new ArrayBuffer[String]() + if (receivers(i).preferredLocation.isDefined) { + locations(i) += receivers(i).preferredLocation.get + } + } + var count = 0 + for (i <- 0 until max(receivers.length, executors.length)) { + if (!receivers(i % receivers.length).preferredLocation.isDefined) { + locations(i % receivers.length) += executors(count) + count += 1 + if (count == executors.length) { + count = 0 + } + } + } + locations + } + /** * Get the receivers from the ReceiverInputDStreams, distributes them to the * worker nodes as a parallel collection, and runs them. @@ -281,20 +320,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false rcvr }) - // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.preferredLocation.isDefined).reduce(_ && _) - - // Create the parallel collection of receivers to distributed them on the worker nodes - val tempRDD = - if (hasLocationPreferences) { - val receiversWithPreferences = receivers.map(r => (r, Seq(r.preferredLocation.get))) - ssc.sc.makeRDD[Receiver[_]](receiversWithPreferences) - } else { - ssc.sc.makeRDD(receivers, receivers.size) - } - val checkpointDirOption = Option(ssc.checkpointDir) - val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + val serializableHadoopConf = + new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[Receiver[_]]) => { @@ -308,12 +336,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false supervisor.start() supervisor.awaitTermination() } + // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. if (!ssc.sparkContext.isLocal) { ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() } + // Get the list of executors and schedule receivers + val executors = getExecutors(ssc) + val tempRDD = + if (!executors.isEmpty) { + val locations = scheduleReceivers(receivers, executors) + val roundRobinReceivers = (0 until receivers.length).map(i => + (receivers(i), locations(i))) + ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers) + } else { + ssc.sc.makeRDD(receivers, receivers.size) + } + // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") running = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 00cc47d6a3ca..f702bd5bc946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -44,8 +44,9 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) val formattedSchedulingDelay = schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val processingTime = batch.processingDelay val formattedProcessingTime = processingTime.map(SparkUIUtils.formatDuration).getOrElse("-") + val batchTimeId = s"batch-$batchTime" - + {formattedBatchTime} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4ee7a486e370..87af902428ec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -310,7 +310,7 @@ private[ui] class StreamingPage(parent: StreamingTab) Timelines (Last {batchTimes.length} batches, {numActiveBatches} active, {numCompletedBatches} completed) - Histograms + Histograms @@ -456,7 +456,7 @@ private[ui] class StreamingPage(parent: StreamingTab) {receiverActive} {receiverLocation} {receiverLastErrorTime} -
    {receiverLastError}
    +
    {receiverLastError}
    diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 87ba4f84a9ce..fe6328b1ce72 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -200,7 +200,7 @@ private[streaming] class FileBasedWriteAheadLog( /** Initialize the log directory or recover existing logs inside the directory */ private def initializeOrRecover(): Unit = synchronized { val logDirectoryPath = new Path(logDirectory) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 4d968f8bfa7a..408936653c79 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -27,7 +27,7 @@ object RawTextHelper { * Splits lines and counts the words. */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OpenHashMap[String,Long] + val map = new OpenHashMap[String, Long] var i = 0 var j = 0 while (iter.hasNext) { @@ -98,7 +98,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 1) { + for (i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 1077b1b2cb7e..a34f23475804 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -364,6 +364,14 @@ private void testReduceByWindow(boolean withInverse) { @SuppressWarnings("unchecked") @Test public void testQueueStream() { + ssc.stop(); + // Create a new JavaStreamingContext without checkpointing + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + List> expected = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 87bc20f79c3c..08faeaa58f41 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -255,7 +255,7 @@ class BasicOperationsSuite extends TestSuiteBase { Seq( ) ) val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) + s1.map(x => (x, 1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) } testOperation(inputData1, inputData2, operation, outputData, true) } @@ -427,9 +427,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("updateStateByKey - object lifecycle") { val inputData = Seq( - Seq("a","b"), + Seq("a", "b"), null, - Seq("a","c","a"), + Seq("a", "c", "a"), Seq("c"), null, null @@ -557,6 +557,9 @@ class BasicOperationsSuite extends TestSuiteBase { withTestServer(new TestServer()) { testServer => withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => testServer.start() + + val batchCounter = new BatchCounter(ssc) + // Set up the streaming context and input streams val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) @@ -587,7 +590,11 @@ class BasicOperationsSuite extends TestSuiteBase { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(200) + val numCompletedBatches = batchCounter.getNumCompletedBatches clock.advance(batchDuration.milliseconds) + if (!batchCounter.waitUntilBatchesCompleted(numCompletedBatches + 1, 5000)) { + fail("Batch took more than 5 seconds to complete") + } collectRddInfo() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala new file mode 100644 index 000000000000..9b5e4dc819a2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.NotSerializableException + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.ReturnStatementInClosureException + +/** + * Test that closures passed to DStream operations are actually cleaned. + */ +class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { + private var ssc: StreamingContext = null + + override def beforeAll(): Unit = { + val sc = new SparkContext("local", "test") + ssc = new StreamingContext(sc, Seconds(1)) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + ssc = null + } + + test("user provided closures are actually cleaned") { + val dstream = new DummyInputDStream(ssc) + val pairDstream = dstream.map { i => (i, i) } + // DStream + testMap(dstream) + testFlatMap(dstream) + testFilter(dstream) + testMapPartitions(dstream) + testReduce(dstream) + testForeach(dstream) + testForeachRDD(dstream) + testTransform(dstream) + testTransformWith(dstream) + testReduceByWindow(dstream) + // PairDStreamFunctions + testReduceByKey(pairDstream) + testCombineByKey(pairDstream) + testReduceByKeyAndWindow(pairDstream) + testUpdateStateByKey(pairDstream) + testMapValues(pairDstream) + testFlatMapValues(pairDstream) + // StreamingContext + testTransform2(ssc, dstream) + } + + /** + * Verify that the expected exception is thrown. + * + * We use return statements as an indication that a closure is actually being cleaned. + * We expect closure cleaner to find the return statements in the user provided closures. + */ + private def expectCorrectException(body: => Unit): Unit = { + try { + body + } catch { + case rse: ReturnStatementInClosureException => // Success! + case e @ (_: NotSerializableException | _: SparkException) => + throw new TestException( + s"Expected ReturnStatementInClosureException, but got $e.\n" + + "This means the closure provided by user is not actually cleaned.") + } + } + + // DStream operations + private def testMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.map { _ => return; 1 } + } + private def testFlatMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.flatMap { _ => return; Seq.empty } + } + private def testFilter(ds: DStream[Int]): Unit = expectCorrectException { + ds.filter { _ => return; true } + } + private def testMapPartitions(ds: DStream[Int]): Unit = expectCorrectException { + ds.mapPartitions { _ => return; Seq.empty.toIterator } + } + private def testReduce(ds: DStream[Int]): Unit = expectCorrectException { + ds.reduce { case (_, _) => return; 1 } + } + private def testForeach(ds: DStream[Int]): Unit = { + val foreachF1 = (rdd: RDD[Int], t: Time) => return + val foreachF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreach(foreachF1) } + expectCorrectException { ds.foreach(foreachF2) } + } + private def testForeachRDD(ds: DStream[Int]): Unit = { + val foreachRDDF1 = (rdd: RDD[Int], t: Time) => return + val foreachRDDF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreachRDD(foreachRDDF1) } + expectCorrectException { ds.foreachRDD(foreachRDDF2) } + } + private def testTransform(ds: DStream[Int]): Unit = { + val transformF1 = (rdd: RDD[Int]) => { return; rdd } + val transformF2 = (rdd: RDD[Int], time: Time) => { return; rdd } + expectCorrectException { ds.transform(transformF1) } + expectCorrectException { ds.transform(transformF2) } + } + private def testTransformWith(ds: DStream[Int]): Unit = { + val transformF1 = (rdd1: RDD[Int], rdd2: RDD[Int]) => { return; rdd1 } + val transformF2 = (rdd1: RDD[Int], rdd2: RDD[Int], time: Time) => { return; rdd2 } + expectCorrectException { ds.transformWith(ds, transformF1) } + expectCorrectException { ds.transformWith(ds, transformF2) } + } + private def testReduceByWindow(ds: DStream[Int]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByWindow(reduceF, reduceF, Seconds(1), Seconds(2)) } + } + + // PairDStreamFunctions operations + private def testReduceByKey(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByKey(reduceF) } + expectCorrectException { ds.reduceByKey(reduceF, 5) } + expectCorrectException { ds.reduceByKey(reduceF, new HashPartitioner(5)) } + } + private def testCombineByKey(ds: DStream[(Int, Int)]): Unit = { + expectCorrectException { + ds.combineByKey[Int]( + { _: Int => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + new HashPartitioner(5) + ) + } + } + private def testReduceByKeyAndWindow(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + val filterF = (_: (Int, Int)) => { return; false } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), 5) } + expectCorrectException { + ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), new HashPartitioner(5)) + } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, reduceF, Seconds(2)) } + expectCorrectException { + ds.reduceByKeyAndWindow( + reduceF, reduceF, Seconds(2), Seconds(3), new HashPartitioner(5), filterF) + } + } + private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = { + val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) } + val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator } + val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) } + expectCorrectException { ds.updateStateByKey(updateF1) } + expectCorrectException { ds.updateStateByKey(updateF1, 5) } + expectCorrectException { ds.updateStateByKey(updateF1, new HashPartitioner(5)) } + expectCorrectException { + ds.updateStateByKey(updateF1, new HashPartitioner(5), initialRDD) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD) + } + } + private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.mapValues { _ => return; 1 } + } + private def testFlatMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.flatMapValues { _ => return; Seq.empty } + } + + // StreamingContext operations + private def testTransform2(ssc: StreamingContext, ds: DStream[Int]): Unit = { + val transformF = (rdds: Seq[RDD[_]], time: Time) => { return; ssc.sparkContext.emptyRDD[Int] } + expectCorrectException { ssc.transform(Seq(ds), transformF) } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala new file mode 100644 index 000000000000..8844c9d74b93 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.ui.UIUtils + +/** + * Tests whether scope information is passed from DStream operations to RDDs correctly. + */ +class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + private var ssc: StreamingContext = null + private val batchDuration: Duration = Seconds(1) + + override def beforeAll(): Unit = { + ssc = new StreamingContext(new SparkContext("local", "test"), batchDuration) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + } + + before { assertPropertiesNotSet() } + after { assertPropertiesNotSet() } + + test("dstream without scope") { + val dummyStream = new DummyDStream(ssc) + dummyStream.initialize(Time(0)) + + // This DStream is not instantiated in any scope, so all RDDs + // created by this stream should similarly not have a scope + assert(dummyStream.baseScope === None) + assert(dummyStream.getOrCompute(Time(1000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(2000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(3000)).get.scope === None) + } + + test("input dstream without scope") { + val inputStream = new DummyInputDStream(ssc) + inputStream.initialize(Time(0)) + + val baseScope = inputStream.baseScope.map(RDDOperationScope.fromJson) + val scope1 = inputStream.getOrCompute(Time(1000)).get.scope + val scope2 = inputStream.getOrCompute(Time(2000)).get.scope + val scope3 = inputStream.getOrCompute(Time(3000)).get.scope + + // This DStream is not instantiated in any scope, so all RDDs + assertDefined(baseScope, scope1, scope2, scope3) + assert(baseScope.get.name.startsWith("dummy stream")) + assertScopeCorrect(baseScope.get, scope1.get, 1000) + assertScopeCorrect(baseScope.get, scope2.get, 2000) + assertScopeCorrect(baseScope.get, scope3.get, 3000) + } + + test("scoping simple operations") { + val inputStream = new DummyInputDStream(ssc) + val mappedStream = inputStream.map { i => i + 1 } + val filteredStream = mappedStream.filter { i => i % 2 == 0 } + filteredStream.initialize(Time(0)) + + val mappedScopeBase = mappedStream.baseScope.map(RDDOperationScope.fromJson) + val mappedScope1 = mappedStream.getOrCompute(Time(1000)).get.scope + val mappedScope2 = mappedStream.getOrCompute(Time(2000)).get.scope + val mappedScope3 = mappedStream.getOrCompute(Time(3000)).get.scope + val filteredScopeBase = filteredStream.baseScope.map(RDDOperationScope.fromJson) + val filteredScope1 = filteredStream.getOrCompute(Time(1000)).get.scope + val filteredScope2 = filteredStream.getOrCompute(Time(2000)).get.scope + val filteredScope3 = filteredStream.getOrCompute(Time(3000)).get.scope + + // These streams are defined in their respective scopes "map" and "filter", so all + // RDDs created by these streams should inherit the IDs and names of their parent + // DStream's base scopes + assertDefined(mappedScopeBase, mappedScope1, mappedScope2, mappedScope3) + assertDefined(filteredScopeBase, filteredScope1, filteredScope2, filteredScope3) + assert(mappedScopeBase.get.name === "map") + assert(filteredScopeBase.get.name === "filter") + assertScopeCorrect(mappedScopeBase.get, mappedScope1.get, 1000) + assertScopeCorrect(mappedScopeBase.get, mappedScope2.get, 2000) + assertScopeCorrect(mappedScopeBase.get, mappedScope3.get, 3000) + assertScopeCorrect(filteredScopeBase.get, filteredScope1.get, 1000) + assertScopeCorrect(filteredScopeBase.get, filteredScope2.get, 2000) + assertScopeCorrect(filteredScopeBase.get, filteredScope3.get, 3000) + } + + test("scoping nested operations") { + val inputStream = new DummyInputDStream(ssc) + val countStream = inputStream.countByWindow(Seconds(10), Seconds(1)) + countStream.initialize(Time(0)) + + val countScopeBase = countStream.baseScope.map(RDDOperationScope.fromJson) + val countScope1 = countStream.getOrCompute(Time(1000)).get.scope + val countScope2 = countStream.getOrCompute(Time(2000)).get.scope + val countScope3 = countStream.getOrCompute(Time(3000)).get.scope + + // Assert that all children RDDs inherit the DStream operation name correctly + assertDefined(countScopeBase, countScope1, countScope2, countScope3) + assert(countScopeBase.get.name === "countByWindow") + assertScopeCorrect(countScopeBase.get, countScope1.get, 1000) + assertScopeCorrect(countScopeBase.get, countScope2.get, 2000) + assertScopeCorrect(countScopeBase.get, countScope3.get, 3000) + + // All streams except the input stream should share the same scopes as `countStream` + def testStream(stream: DStream[_]): Unit = { + if (stream != inputStream) { + val myScopeBase = stream.baseScope.map(RDDOperationScope.fromJson) + val myScope1 = stream.getOrCompute(Time(1000)).get.scope + val myScope2 = stream.getOrCompute(Time(2000)).get.scope + val myScope3 = stream.getOrCompute(Time(3000)).get.scope + assertDefined(myScopeBase, myScope1, myScope2, myScope3) + assert(myScopeBase === countScopeBase) + assert(myScope1 === countScope1) + assert(myScope2 === countScope2) + assert(myScope3 === countScope3) + // Climb upwards to test the parent streams + stream.dependencies.foreach(testStream) + } + } + testStream(countStream) + } + + /** Assert that the RDD operation scope properties are not set in our SparkContext. */ + private def assertPropertiesNotSet(): Unit = { + assert(ssc != null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY) == null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY) == null) + } + + /** Assert that the given RDD scope inherits the name and ID of the base scope correctly. */ + private def assertScopeCorrect( + baseScope: RDDOperationScope, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + assertScopeCorrect(baseScope.id, baseScope.name, rddScope, batchTime) + } + + /** Assert that the given RDD scope inherits the base name and ID correctly. */ + private def assertScopeCorrect( + baseScopeId: String, + baseScopeName: String, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + val formattedBatchTime = UIUtils.formatBatchTime( + batchTime, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + assert(rddScope.id === s"${baseScopeId}_$batchTime") + assert(rddScope.name.replaceAll("\\n", " ") === s"$baseScopeName @ $formattedBatchTime") + } + + /** Assert that all the specified options are defined. */ + private def assertDefined[T](options: Option[T]*): Unit = { + options.zipWithIndex.foreach { case (o, i) => assert(o.isDefined, s"Option $i was empty!") } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 93e6b0cd7c66..b74d67c63a78 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -105,6 +106,36 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("socket input stream - no block in a batch") { + withTestServer(new TestServer()) { testServer => + testServer.start() + + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.addStreamingListener(ssc.progressListener) + + val batchCounter = new BatchCounter(ssc) + val networkStream = ssc.socketTextStream( + "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + outputStream.register() + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // Make sure the first batch is finished + if (!batchCounter.waitUntilBatchesCompleted(1, 30000)) { + fail("Timeout: cannot finish all batches in 30 seconds") + } + + networkStream.generatedRDDs.foreach { case (_, rdd) => + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + } + } + } + } + test("binary records stream") { val testDir: File = null try { @@ -387,7 +418,7 @@ class TestServer(portToBind: Int = 0) extends Logging { val servingThread = new Thread() { override def run() { try { - while(true) { + while (true) { logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() if (startLatch.getCount == 1) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 23804237bda8..6c0c926755c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -41,11 +41,14 @@ import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ -class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class ReceivedBlockHandlerSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with Logging { val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() - val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -53,10 +56,12 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val serializer = new KryoSerializer(conf) val manualClock = new ManualClock val blockManagerSize = 10000000 + val blockManagerBuffer = new ArrayBuffer[BlockManager]() var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null + var storageLevel: StorageLevel = null var tempDirectory: File = null before { @@ -66,20 +71,21 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, - blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr), securityMgr, 0) - blockManager.initialize("app-id") + storageLevel = StorageLevel.MEMORY_ONLY_SER + blockManager = createBlockManager(blockManagerSize, conf) tempDirectory = Utils.createTempDir() manualClock.setTime(0) } after { - if (blockManager != null) { - blockManager.stop() - blockManager = null + for ( blockManager <- blockManagerBuffer ) { + if (blockManager != null) { + blockManager.stop() + } } + blockManager = null + blockManagerBuffer.clear() if (blockManagerMaster != null) { blockManagerMaster.stop() blockManagerMaster = null @@ -170,6 +176,130 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche } } + test("Test Block - count messages") { + // Test count with BlockManagedBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(true) + // Test count with WriteAheadLogBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(false) + } + + test("Test Block - isFullyConsumed") { + val sparkConf = new SparkConf() + sparkConf.set("spark.storage.unrollMemoryThreshold", "512") + // spark.storage.unrollFraction set to 0.4 for BlockManager + sparkConf.set("spark.storage.unrollFraction", "0.4") + // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll + blockManager = createBlockManager(12000, sparkConf) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to WAL + // and hence count returns correct value. + testRecordcount(false, StorageLevel.MEMORY_ONLY, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to DISK + // and hence count returns correct value. + testRecordcount(true, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block With MEMORY_ONLY StorageLevel. + // BlockManager will not be able to unroll this block + // and hence it will not tryToPut this block, resulting the SparkException + storageLevel = StorageLevel.MEMORY_ONLY + withBlockManagerBasedBlockHandler { handler => + val thrown = intercept[SparkException] { + storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator)) + } + } + } + + private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) { + // ByteBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ByteBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ArrayBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50)) + // ArrayBufferBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75)) + // IteratorBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, Some(125)) + // IteratorBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150)) + } + + private def createBlockManager( + maxMem: Long, + conf: SparkConf, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + manager.initialize("app-id") + blockManagerBuffer += manager + manager + } + + /** + * Test storing of data using different types of Handler, StorageLevle and ReceivedBlocks + * and verify the correct record count + */ + private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, + sLevel: StorageLevel, + receivedBlock: ReceivedBlock, + bManager: BlockManager, + expectedNumRecords: Option[Long] + ) { + blockManager = bManager + storageLevel = sLevel + var bId: StreamBlockId = null + try { + if (isBlockManagedBasedBlockHandler) { + // test received block with BlockManager based handler + withBlockManagerBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using BlockManagerBasedBlockHandler with " + sLevel) + } + } else { + // test received block with WAL based handler + withWriteAheadLogBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel) + } + } + } finally { + // Removing the Block Id to use same blockManager for next test + blockManager.removeBlock(bId, true) + } + } + /** * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded * using the given verification function @@ -247,9 +377,21 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche (blockIds, storeResults) } + /** Store single block using a handler */ + private def storeSingleBlock( + handler: ReceivedBlockHandler, + block: ReceivedBlock + ): (StreamBlockId, ReceivedBlockStoreResult) = { + val blockId = generateBlockId + val blockStoreResult = handler.storeBlock(blockId, block) + logDebug("Done inserting") + (blockId, blockStoreResult) + } + private def getWriteAheadLogFiles(): Seq[String] = { getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) } private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) } + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b1af8d5eaacf..f793a12843b2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ @@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite - extends FunSuite with BeforeAndAfter with Matchers with Logging { + extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() val akkaTimeout = 10 seconds @@ -224,8 +224,8 @@ class ReceivedBlockTrackerSuite /** Generate blocks infos using random ids */ def generateBlockInfos(): Seq[ReceivedBlockInfo] = { - List.fill(5)(ReceivedBlockInfo(streamId, 0, None, - BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } /** Get all the data written in the given write ahead log file. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4b12affbb0dd..56b4ce5638a5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,24 +17,26 @@ package org.apache.spark.streaming -import java.io.File +import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.Queue + import org.apache.commons.io.FileUtils -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} -import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ +import org.scalatest.{Assertions, BeforeAndAfter} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite} -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName @@ -132,6 +134,41 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("start with non-seriazable DStream checkpoints") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir.getAbsolutePath) + addInputStream(ssc).foreachRDD { rdd => + // Refer to this.appName from inside closure so that this closure refers to + // the instance of StreamingContextSuite, and is therefore not serializable + rdd.count() + appName + } + + // Test whether start() fails early when checkpointing is enabled + val exception = intercept[NotSerializableException] { + ssc.start() + } + assert(exception.getMessage().contains("DStreams with their functions are not serializable")) + assert(ssc.getState() !== StreamingContextState.ACTIVE) + assert(StreamingContext.getActive().isEmpty) + } + + test("start failure should stop internal components") { + ssc = new StreamingContext(conf, batchDuration) + val inputStream = addInputStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + // Require that the start fails because checkpoint directory was not set + intercept[Exception] { + ssc.start() + } + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(ssc.scheduler.isStarted === false) + } + + test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() @@ -163,7 +200,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() - intercept[SparkException] { + intercept[IllegalStateException] { ssc.start() // start after stop should throw exception } assert(ssc.getState() === StreamingContextState.STOPPED) @@ -581,7 +618,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val anotherInput = addInputStream(anotherSsc) anotherInput.foreachRDD { rdd => rdd.count } - val exception = intercept[SparkException] { + val exception = intercept[IllegalStateException] { anotherSsc.start() } assert(exception.getMessage.contains("StreamingContext"), "Did not get the right exception") @@ -604,7 +641,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w def testForException(clue: String, expectedErrorMsg: String)(body: => Unit): Unit = { withClue(clue) { - val ex = intercept[SparkException] { + val ex = intercept[IllegalStateException] { body } assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) @@ -630,6 +667,19 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w transformed.foreachRDD { rdd => rdd.collect() } } } + test("queueStream doesn't support checkpointing") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(master, appName, batchDuration) + val rdd = ssc.sparkContext.parallelize(1 to 10) + ssc.queueStream[Int](Queue(rdd)).print() + ssc.checkpoint(checkpointDir.getAbsolutePath) + val e = intercept[NotSerializableException] { + ssc.start() + } + // StreamingContext.validate changes the message, so use "contains" here + assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) @@ -713,7 +763,9 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) def onStop() { // Simulate slow receiver by waiting for all records to be produced - while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + while (!SlowTestReceiver.receivedAllRecords) { + Thread.sleep(100) + } // no clean to be done, the receiving thread should stop on it own } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 312cce408cfe..7bc7727a9fbe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -116,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 @@ -133,8 +133,10 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for(i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) return false + for (i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) { + return false + } } true } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 4f70ae7f1f18..31b1aebf6a8e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,17 +24,35 @@ import scala.collection.mutable.SynchronizedBuffer import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} +/** + * A dummy stream that does absolutely nothing. + */ +private[streaming] class DummyDStream(ssc: StreamingContext) extends DStream[Int](ssc) { + override def dependencies: List[DStream[Int]] = List.empty + override def slideDuration: Duration = Seconds(1) + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} + +/** + * A dummy input stream that does absolutely nothing. + */ +private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputDStream[Int](ssc) { + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} + /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, * replayable, reliable message queue like Kafka. It requires a sequence as input, and @@ -186,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) { * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context def framework: String = this.getClass.getSimpleName diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 441bbf95d015..a08578680cff 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -27,20 +27,20 @@ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ import org.apache.spark._ - - - +import org.apache.spark.ui.SparkUICssErrorHandler /** - * Selenium tests for the Spark Web UI. + * Selenium tests for the Spark Streaming Web UI. */ class UISeleniumSuite - extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ override def beforeAll(): Unit = { - webDriver = new HtmlUnitDriver + webDriver = new HtmlUnitDriver { + getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) + } } override def afterAll(): Unit = { @@ -197,4 +197,3 @@ class UISeleniumSuite } } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 6859b65c7165..cb017b798b2a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,15 +21,15 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite - extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 5478b4184594..2e210397fe7c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.streaming.{Time, Duration, StreamingContext} -class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { +class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { private var ssc: StreamingContext = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala new file mode 100644 index 000000000000..a6e783861dbe --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.Utils + +/** Testsuite for receiver scheduling */ +class ReceiverTrackerSuite extends TestSuiteBase { + val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") + val ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val tracker = new ReceiverTracker(ssc) + val launcher = new tracker.ReceiverLauncher() + val executors: List[String] = List("0", "1", "2", "3") + + test("receiver scheduling - all or none have preferred location") { + + def parse(s: String): Array[Array[String]] = { + val outerSplit = s.split("\\|") + val loc = new Array[Array[String]](outerSplit.length) + var i = 0 + for (i <- 0 until outerSplit.length) { + loc(i) = outerSplit(i).split("\\,") + } + loc + } + + def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { + val receivers = + if (preferredLocation) { + Array.tabulate(numReceivers)(i => new DummyReceiver(host = + Some(((i + 1) % executors.length).toString))) + } else { + Array.tabulate(numReceivers)(_ => new DummyReceiver) + } + val locations = launcher.scheduleReceivers(receivers, executors) + val expectedLocations = parse(allocation) + assert(locations.deep === expectedLocations.deep) + } + + testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") + testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") + testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") + } + + test("receiver scheduling - some have preferred location") { + val numReceivers = 4; + val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), + new DummyReceiver, new DummyReceiver, new DummyReceiver) + val locations = launcher.scheduleReceivers(receivers, executors) + assert(locations(0)(0) === "1") + assert(locations(1)(0) === "0") + assert(locations(2)(0) === "1") + assert(locations(0).length === 1) + assert(locations(3).length === 1) + } +} + +/** + * Dummy receiver implementation + */ +private class DummyReceiver(host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + def onStart() { + } + + def onStop() { + } + + override def preferredLocation: Option[String] = host +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 2a0f45830e03..c9175d61b1f4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -64,7 +64,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala index e9ab917ab845..d3ca2b58f36c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.streaming.ui import java.util.TimeZone import java.util.concurrent.TimeUnit -import org.scalatest.FunSuite import org.scalatest.Matchers -class UIUtilsSuite extends FunSuite with Matchers{ +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite with Matchers{ test("shortTimeUnitString") { assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 9ebf7b484f42..78fc344b0017 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class RateLimitedOutputStreamSuite extends FunSuite { +class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 79098bcf4861..325ff7c74c39 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -28,15 +28,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { +class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null @@ -359,7 +359,7 @@ object WriteAheadLogSuite { ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => manualClock.advance(500) diff --git a/tools/pom.xml b/tools/pom.xml index 1c6f3e83a181..feffde4c857e 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 9e151fc7a914..33782c6c66f9 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -65,6 +65,11 @@ junit-interface test
    + + org.mockito + mockito-core + test + target/scala-${scala.binary.version}/classes @@ -75,7 +80,7 @@ net.alchim31.maven scala-maven-plugin - + -XDignore.symbol.file diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index 24b289209805..192c6714b240 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -25,8 +25,7 @@ public final class PlatformDependent { /** * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us aovid accidental use of deprecated methods or methods that - * aren't present in Java 6. + * this package. This also lets us avoid accidental use of deprecated methods. */ public static final class UNSAFE { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 19d6a169fd2a..0b4d8d286f5f 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -23,6 +23,8 @@ import java.util.LinkedList; import java.util.List; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.unsafe.*; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; @@ -36,9 +38,8 @@ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers, * which is guaranteed to exhaust the space. *

    - * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is - * higher than this, you should probably be using sorting instead of hashing for better cache - * locality. + * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should + * probably be using sorting instead of hashing for better cache locality. *

    * This class is not thread safe. */ @@ -48,6 +49,11 @@ public final class BytesToBytesMap { private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + /** + * Special record length that is placed after the last record in a data page. + */ + private static final int END_OF_PAGE_MARKER = -1; + private final TaskMemoryManager memoryManager; /** @@ -64,7 +70,7 @@ public final class BytesToBytesMap { /** * Offset into `currentDataPage` that points to the location where new data can be inserted into - * the page. + * the page. This does not incorporate the page's base offset. */ private long pageCursor = 0; @@ -74,6 +80,15 @@ public final class BytesToBytesMap { */ private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + /** + * The maximum number of keys that BytesToBytesMap supports. The hash table has to be + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since + * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). + */ + @VisibleForTesting + static final int MAX_CAPACITY = (1 << 29); + // This choice of page table size and page size means that we can address up to 500 gigabytes // of memory. @@ -143,6 +158,13 @@ public BytesToBytesMap( this.loadFactor = loadFactor; this.loc = new Location(); this.enablePerfMetrics = enablePerfMetrics; + if (initialCapacity <= 0) { + throw new IllegalArgumentException("Initial capacity must be greater than 0"); + } + if (initialCapacity > MAX_CAPACITY) { + throw new IllegalArgumentException( + "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); + } allocate(initialCapacity); } @@ -162,6 +184,55 @@ public BytesToBytesMap( */ public int size() { return size; } + private static final class BytesToBytesMapIterator implements Iterator { + + private final int numRecords; + private final Iterator dataPagesIterator; + private final Location loc; + + private int currentRecordNumber = 0; + private Object pageBaseObject; + private long offsetInPage; + + BytesToBytesMapIterator(int numRecords, Iterator dataPagesIterator, Location loc) { + this.numRecords = numRecords; + this.dataPagesIterator = dataPagesIterator; + this.loc = loc; + if (dataPagesIterator.hasNext()) { + advanceToNextPage(); + } + } + + private void advanceToNextPage() { + final MemoryBlock currentPage = dataPagesIterator.next(); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + } + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public Location next() { + int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + if (keyLength == END_OF_PAGE_MARKER) { + advanceToNextPage(); + keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + } + loc.with(pageBaseObject, offsetInPage); + offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + currentRecordNumber++; + return loc; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + /** * Returns an iterator for iterating over the entries of this map. * @@ -171,27 +242,7 @@ public BytesToBytesMap( * `lookup()`, the behavior of the returned iterator is undefined. */ public Iterator iterator() { - return new Iterator() { - - private int nextPos = bitset.nextSetBit(0); - - @Override - public boolean hasNext() { - return nextPos != -1; - } - - @Override - public Location next() { - final int pos = nextPos; - nextPos = bitset.nextSetBit(nextPos + 1); - return loc.with(pos, 0, true); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; + return new BytesToBytesMapIterator(size, dataPages.iterator(), loc); } /** @@ -268,8 +319,11 @@ public final class Location { private int valueLength; private void updateAddressesAndSizes(long fullKeyAddress) { - final Object page = memoryManager.getPage(fullKeyAddress); - final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress); + updateAddressesAndSizes( + memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress)); + } + + private void updateAddressesAndSizes(Object page, long keyOffsetInPage) { long position = keyOffsetInPage; keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); position += 8; // word used to store the key size @@ -291,6 +345,12 @@ Location with(int pos, int keyHashcode, boolean isDefined) { return this; } + Location with(Object page, long keyOffsetInPage) { + this.isDefined = true; + updateAddressesAndSizes(page, keyOffsetInPage); + return this; + } + /** * Returns true if the key is defined at this position, and false otherwise. */ @@ -345,6 +405,8 @@ public int getValueLength() { *

    * It is only valid to call this method immediately after calling `lookup()` using the same key. *

    + * The key and value must be word-aligned (that is, their sizes must multiples of 8). + *

    * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` * will return information on the data stored by this `putNewKey` call. *

    @@ -367,20 +429,29 @@ public void putNewKey( long valueBaseOffset, int valueLengthBytes) { assert (!isDefined) : "Can only set value once for a key"; - isDefined = true; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); + if (size == MAX_CAPACITY) { + throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); + } // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert(requiredSize <= PAGE_SIZE_BYTES); + assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. size++; bitset.set(pos); - // If there's not enough space in the current page, allocate a new page: - if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) { + // If there's not enough space in the current page, allocate a new page (8 bytes are reserved + // for the end-of-page marker). + if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) { + if (currentDataPage != null) { + // There wasn't enough space in the current page, so write an end-of-page marker: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; + PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + } MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); dataPages.add(newPage); pageCursor = 0; @@ -414,7 +485,7 @@ public void putNewKey( longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); isDefined = true; - if (size > growthThreshold) { + if (size > growthThreshold && longArray.size() < MAX_CAPACITY) { growAndRehash(); } } @@ -427,8 +498,11 @@ public void putNewKey( * @param capacity the new map capacity */ private void allocate(int capacity) { - capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); - longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2)); + assert (capacity >= 0); + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); + assert (capacity <= MAX_CAPACITY); + longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2)); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); this.growthThreshold = (int) (capacity * loadFactor); @@ -494,10 +568,16 @@ public long getNumHashCollisions() { return numHashCollisions; } + @VisibleForTesting + int getNumDataPages() { + return dataPages.size(); + } + /** * Grows the size of the hash table and re-hash everything. */ - private void growAndRehash() { + @VisibleForTesting + void growAndRehash() { long resizeStartTime = -1; if (enablePerfMetrics) { resizeStartTime = System.nanoTime(); @@ -508,7 +588,7 @@ private void growAndRehash() { final int oldCapacity = (int) oldBitSet.capacity(); // Allocate the new data structures - allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity))); + allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index 7c321baffe82..20654e4eeaa0 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -32,7 +32,9 @@ public interface HashMapGrowthStrategy { class Doubling implements HashMapGrowthStrategy { @Override public int nextCapacity(int currentCapacity) { - return currentCapacity * 2; + assert (currentCapacity > 0); + // Guard against overflow + return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE; } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java index 62c29c8cc1e4..cbbe8594627a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -17,6 +17,12 @@ package org.apache.spark.unsafe.memory; +import java.lang.ref.WeakReference; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import javax.annotation.concurrent.GuardedBy; + /** * Manages memory for an executor. Individual operators / tasks allocate memory through * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. @@ -33,6 +39,12 @@ public class ExecutorMemoryManager { */ final boolean inHeap; + @GuardedBy("this") + private final Map>> bufferPoolsBySize = + new HashMap>>(); + + private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; + /** * Construct a new ExecutorMemoryManager. * @@ -43,16 +55,57 @@ public ExecutorMemoryManager(MemoryAllocator allocator) { this.allocator = allocator; } + /** + * Returns true if allocations of the given size should go through the pooling mechanism and + * false otherwise. + */ + private boolean shouldPool(long size) { + // Very small allocations are less likely to benefit from pooling. + // At some point, we should explore supporting pooling for off-heap memory, but for now we'll + // ignore that case in the interest of simplicity. + return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator; + } + /** * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed * to be zeroed out (call `zero()` on the result if this is necessary). */ MemoryBlock allocate(long size) throws OutOfMemoryError { - return allocator.allocate(size); + if (shouldPool(size)) { + synchronized (this) { + final LinkedList> pool = bufferPoolsBySize.get(size); + if (pool != null) { + while (!pool.isEmpty()) { + final WeakReference blockReference = pool.pop(); + final MemoryBlock memory = blockReference.get(); + if (memory != null) { + assert (memory.size() == size); + return memory; + } + } + bufferPoolsBySize.remove(size); + } + } + return allocator.allocate(size); + } else { + return allocator.allocate(size); + } } void free(MemoryBlock memory) { - allocator.free(memory); + final long size = memory.size(); + if (shouldPool(size)) { + synchronized (this) { + LinkedList> pool = bufferPoolsBySize.get(size); + if (pool == null) { + pool = new LinkedList>(); + bufferPoolsBySize.put(size, pool); + } + pool.add(new WeakReference(memory)); + } + } else { + allocator.free(memory); + } } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 2906ac8abad1..10881969dbc7 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -44,7 +44,7 @@ * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is * approximately 35 terabytes of memory. */ -public final class TaskMemoryManager { +public class TaskMemoryManager { private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java new file mode 100644 index 000000000000..9302b472925e --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import javax.annotation.Nonnull; +import java.io.Serializable; +import java.io.UnsupportedEncodingException; +import java.util.Arrays; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A UTF-8 String for internal Spark use. + *

    + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + *

    + * Note: This is not designed for general use cases, should not be used outside SQL. + */ +public final class UTF8String implements Comparable, Serializable { + + @Nonnull + private byte[] bytes; + + private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6}; + + public static UTF8String fromBytes(byte[] bytes) { + return (bytes != null) ? new UTF8String().set(bytes) : null; + } + + public static UTF8String fromString(String str) { + return (str != null) ? new UTF8String().set(str) : null; + } + + /** + * Updates the UTF8String with String. + */ + protected UTF8String set(final String str) { + try { + bytes = str.getBytes("utf-8"); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + PlatformDependent.throwException(e); + } + return this; + } + + /** + * Updates the UTF8String with byte[], which should be encoded in UTF-8. + */ + protected UTF8String set(final byte[] bytes) { + this.bytes = bytes; + return this; + } + + /** + * Returns the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + public int numBytes(final byte b) { + final int offset = (b & 0xFF) - 192; + return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + } + + /** + * Returns the number of code points in it. + * + * This is only used by Substring() when `start` is negative. + */ + public int length() { + int len = 0; + for (int i = 0; i < bytes.length; i+= numBytes(bytes[i])) { + len += 1; + } + return len; + } + + public byte[] getBytes() { + return bytes; + } + + /** + * Returns a substring of this. + * @param start the position of first code point + * @param until the position after last code point, exclusive. + */ + public UTF8String substring(final int start, final int until) { + if (until <= start || start >= bytes.length) { + return UTF8String.fromBytes(new byte[0]); + } + + int i = 0; + int c = 0; + for (; i < bytes.length && c < start; i += numBytes(bytes[i])) { + c += 1; + } + + int j = i; + for (; j < bytes.length && c < until; j += numBytes(bytes[i])) { + c += 1; + } + + return UTF8String.fromBytes(Arrays.copyOfRange(bytes, i, j)); + } + + public boolean contains(final UTF8String substring) { + final byte[] b = substring.getBytes(); + if (b.length == 0) { + return true; + } + + for (int i = 0; i <= bytes.length - b.length; i++) { + if (bytes[i] == b[0] && startsWith(b, i)) { + return true; + } + } + return false; + } + + private boolean startsWith(final byte[] prefix, int offsetInBytes) { + if (prefix.length + offsetInBytes > bytes.length || offsetInBytes < 0) { + return false; + } + int i = 0; + while (i < prefix.length && prefix[i] == bytes[i + offsetInBytes]) { + i++; + } + return i == prefix.length; + } + + public boolean startsWith(final UTF8String prefix) { + return startsWith(prefix.getBytes(), 0); + } + + public boolean endsWith(final UTF8String suffix) { + return startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length); + } + + public UTF8String toUpperCase() { + return UTF8String.fromString(toString().toUpperCase()); + } + + public UTF8String toLowerCase() { + return UTF8String.fromString(toString().toLowerCase()); + } + + @Override + public String toString() { + try { + return new String(bytes, "utf-8"); + } catch (UnsupportedEncodingException e) { + // Turn the exception into unchecked so we can find out about it at runtime, but + // don't need to add lots of boilerplate code everywhere. + PlatformDependent.throwException(e); + return "unknown"; // we will never reach here. + } + } + + @Override + public UTF8String clone() { + return new UTF8String().set(bytes); + } + + @Override + public int compareTo(final UTF8String other) { + final byte[] b = other.getBytes(); + for (int i = 0; i < bytes.length && i < b.length; i++) { + int res = bytes[i] - b[i]; + if (res != 0) { + return res; + } + } + return bytes.length - b.length; + } + + public int compare(final UTF8String other) { + return compareTo(other); + } + + @Override + public boolean equals(final Object other) { + if (other instanceof UTF8String) { + return Arrays.equals(bytes, ((UTF8String) other).getBytes()); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Arrays.hashCode(bytes); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java index 18393db9f382..a93fc0ee297c 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.unsafe.bitset; import junit.framework.Assert; -import org.apache.spark.unsafe.bitset.BitSet; import org.junit.Test; import org.apache.spark.unsafe.memory.MemoryBlock; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 7a5c0622d1ff..81315f7c9464 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -25,24 +25,40 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.mockito.AdditionalMatchers.geq; +import static org.mockito.Mockito.*; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.*; import org.apache.spark.unsafe.PlatformDependent; import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; + public abstract class AbstractBytesToBytesMapSuite { private final Random rand = new Random(42); private TaskMemoryManager memoryManager; + private TaskMemoryManager sizeLimitedMemoryManager; @Before public void setup() { memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + // Mocked memory manager for tests that check the maximum array size, since actually allocating + // such large arrays will cause us to run out of memory in our tests. + sizeLimitedMemoryManager = spy(memoryManager); + when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer() { + @Override + public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { + if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { + throw new OutOfMemoryError("Requested array size exceeds VM limit"); + } + return memoryManager.allocate(1L << 20); + } + }); } @After @@ -101,6 +117,7 @@ public void emptyMap() { final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); } @@ -159,7 +176,7 @@ public void setAndRetrieveAKey() { @Test public void iteratorTest() throws Exception { - final int size = 128; + final int size = 4096; BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); try { for (long i = 0; i < size; i++) { @@ -167,14 +184,26 @@ public void iteratorTest() throws Exception { final BytesToBytesMap.Location loc = map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); - loc.putNewKey( - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8, - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8 - ); + // Ensure that we store some zero-length keys + if (i % 5 == 0) { + loc.putNewKey( + null, + PlatformDependent.LONG_ARRAY_OFFSET, + 0, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } else { + loc.putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } } final java.util.BitSet valuesSeen = new java.util.BitSet(size); final Iterator iter = map.iterator(); @@ -183,11 +212,16 @@ public void iteratorTest() throws Exception { Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); final long value = PlatformDependent.UNSAFE.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); - Assert.assertEquals(key, value); + final long keyLength = loc.getKeyLength(); + if (keyLength == 0) { + Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); + } else { + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + Assert.assertEquals(value, key); + } valuesSeen.set((int) value); } Assert.assertEquals(size, valuesSeen.cardinality()); @@ -196,6 +230,74 @@ public void iteratorTest() throws Exception { } } + @Test + public void iteratingOverDataPagesWithWastedSpace() throws Exception { + final int NUM_ENTRIES = 1000 * 1000; + final int KEY_LENGTH = 16; + final int VALUE_LENGTH = 40; + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES); + // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte + // pages won't be evenly-divisible by records of this size, which will cause us to waste some + // space at the end of the page. This is necessary in order for us to take the end-of-record + // handling branch in iterator(). + try { + for (int i = 0; i < NUM_ENTRIES; i++) { + final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes + final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes + final BytesToBytesMap.Location loc = map.lookup( + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH, + value, + LONG_ARRAY_OFFSET, + VALUE_LENGTH + ); + } + Assert.assertEquals(2, map.getNumDataPages()); + + final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); + final Iterator iter = map.iterator(); + final long key[] = new long[KEY_LENGTH / 8]; + final long value[] = new long[VALUE_LENGTH / 8]; + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); + Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); + PlatformDependent.copyMemory( + loc.getKeyAddress().getBaseObject(), + loc.getKeyAddress().getBaseOffset(), + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + PlatformDependent.copyMemory( + loc.getValueAddress().getBaseObject(), + loc.getValueAddress().getBaseOffset(), + value, + LONG_ARRAY_OFFSET, + VALUE_LENGTH + ); + for (long j : key) { + Assert.assertEquals(key[0], j); + } + for (long j : value) { + Assert.assertEquals(key[0], j); + } + valuesSeen.set((int) key[0]); + } + Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + @Test public void randomizedStressTest() { final int size = 65536; @@ -247,4 +349,35 @@ public void randomizedStressTest() { map.free(); } } + + @Test + public void initialCapacityBoundsChecking() { + try { + new BytesToBytesMap(sizeLimitedMemoryManager, 0); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + try { + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + // Can allocate _at_ the max capacity + BytesToBytesMap map = + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY); + map.free(); + } + + @Test + public void resizingLargeMap() { + // As long as a map's capacity is below the max, we should be able to resize up to the max + BytesToBytesMap map = + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64); + map.growAndRehash(); + map.free(); + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java new file mode 100644 index 000000000000..796cdc9dbebd --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -0,0 +1,91 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.unsafe.types; + +import java.io.UnsupportedEncodingException; + +import junit.framework.Assert; +import org.junit.Test; + +public class UTF8StringSuite { + + private void checkBasic(String str, int len) throws UnsupportedEncodingException { + Assert.assertEquals(UTF8String.fromString(str).length(), len); + Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len); + + Assert.assertEquals(UTF8String.fromString(str).toString(), str); + Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str); + Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str)); + + Assert.assertEquals(UTF8String.fromString(str).hashCode(), + UTF8String.fromBytes(str.getBytes("utf8")).hashCode()); + } + + @Test + public void basicTest() throws UnsupportedEncodingException { + checkBasic("hello", 5); + checkBasic("世 界", 3); + } + + @Test + public void contains() { + Assert.assertTrue(UTF8String.fromString("hello").contains(UTF8String.fromString("ello"))); + Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("vello"))); + Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("hellooo"))); + Assert.assertTrue(UTF8String.fromString("大千世界").contains(UTF8String.fromString("千世"))); + Assert.assertFalse(UTF8String.fromString("大千世界").contains(UTF8String.fromString("世千"))); + Assert.assertFalse( + UTF8String.fromString("大千世界").contains(UTF8String.fromString("大千世界好"))); + } + + @Test + public void startsWith() { + Assert.assertTrue(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hell"))); + Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("ell"))); + Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hellooo"))); + Assert.assertTrue(UTF8String.fromString("数据砖头").startsWith(UTF8String.fromString("数据"))); + Assert.assertFalse(UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("千"))); + Assert.assertFalse( + UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("大千世界好"))); + } + + @Test + public void endsWith() { + Assert.assertTrue(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ello"))); + Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ellov"))); + Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("hhhello"))); + Assert.assertTrue(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世界"))); + Assert.assertFalse(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世"))); + Assert.assertFalse( + UTF8String.fromString("数据砖头").endsWith(UTF8String.fromString("我的数据砖头"))); + } + + @Test + public void substring() { + Assert.assertEquals( + UTF8String.fromString("hello").substring(0, 0), UTF8String.fromString("")); + Assert.assertEquals( + UTF8String.fromString("hello").substring(1, 3), UTF8String.fromString("el")); + Assert.assertEquals( + UTF8String.fromString("数据砖头").substring(0, 1), UTF8String.fromString("数")); + Assert.assertEquals( + UTF8String.fromString("数据砖头").substring(1, 3), UTF8String.fromString("据砖")); + Assert.assertEquals( + UTF8String.fromString("数据砖头").substring(3, 5), UTF8String.fromString("头")); + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 00d219f83670..2aeed98285aa 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.4.0-SNAPSHOT + 1.5.0-SNAPSHOT ../pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.hadoop hadoop-yarn-api @@ -100,7 +107,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index aaae6f9734a8..56e4741b9387 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -60,8 +60,13 @@ private[yarn] class AMDelegationTokenRenewer( private val hadoopUtil = YarnSparkHadoopUtil.get - private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) - private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val daysToKeepFiles = + sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = + sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val freshHadoopConf = + hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -120,8 +125,8 @@ private[yarn] class AMDelegationTokenRenewer( private def cleanupOldFiles(): Unit = { import scala.concurrent.duration._ try { - val remoteFs = FileSystem.get(hadoopConf) - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val remoteFs = FileSystem.get(freshHadoopConf) + val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( remoteFs, credentialsPath.getParent, @@ -160,19 +165,19 @@ private[yarn] class AMDelegationTokenRenewer( val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val dst = credentialsPath.getParent keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials override def run(): Void = { val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst - hadoopUtil.obtainTokensForNamenodes(nns, hadoopConf, tempCreds) + hadoopUtil.obtainTokensForNamenodes(nns, freshHadoopConf, tempCreds) null } }) // Add the temp credentials back to the original ones. UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file // and update the lastCredentialsFileSuffix. @@ -186,13 +191,12 @@ private[yarn] class AMDelegationTokenRenewer( } val nextSuffix = lastCredentialsFileSuffix + 1 val tokenPathStr = - sparkConf.get("spark.yarn.credentials.file") + - SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix val tokenPath = new Path(tokenPathStr) val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, hadoopConf) + credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") remoteFs.rename(tempTokenPath, tokenPath) logInfo("Delegation token file rename complete.") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 29752969e615..83dafa4a125d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -32,9 +32,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException -import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer -import org.apache.spark.scheduler.cluster.YarnSchedulerBackend +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util._ @@ -46,6 +46,14 @@ private[spark] class ApplicationMaster( client: YarnRMClient) extends Logging { + // Load the properties file with the Spark configuration and set entries as system properties, + // so that user code run inside the AM also has access to them. + if (args.propertiesFile != null) { + Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) => + sys.props(k) = v + } + } + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -67,6 +75,7 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + private val allocatorLock = new Object() // Fields used in client mode. private var rpcEnv: RpcEnv = null @@ -220,7 +229,7 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -231,8 +240,14 @@ private[spark] class ApplicationMaster( .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } .getOrElse("") - allocator = client.register(yarnConf, - if (sc != null) sc.getConf else sparkConf, + val _sparkConf = if (sc != null) sc.getConf else sparkConf + val driverUrl = _rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + allocator = client.register(driverUrl, + yarnConf, + _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, @@ -279,7 +294,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -289,7 +304,7 @@ private[spark] class ApplicationMaster( rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -300,11 +315,14 @@ private[spark] class ApplicationMaster( val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "5s") + val heartbeatInterval = math.max(0, math.min(expiryInterval / 2, + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) - // must be <= expiryInterval / 2. - val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) + // we want to check more frequently for pending containers + val initialAllocationInterval = math.min(heartbeatInterval, + sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) + + var nextAllocationInterval = initialAllocationInterval // The number of failures in a row until Reporter thread give up val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) @@ -330,15 +348,29 @@ private[spark] class ApplicationMaster( if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + - s"${failureCount} time(s) from Reporter thread.") - + s"$failureCount time(s) from Reporter thread.") } else { - logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) + logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e) } } } try { - Thread.sleep(interval) + val numPendingAllocate = allocator.getNumPendingAllocate + val sleepInterval = + if (numPendingAllocate > 0) { + val currentAllocationInterval = + math.min(heartbeatInterval, nextAllocationInterval) + nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow + currentAllocationInterval + } else { + nextAllocationInterval = initialAllocationInterval + heartbeatInterval + } + logDebug(s"Number of pending allocations is $numPendingAllocate. " + + s"Sleeping for $sleepInterval.") + allocatorLock.synchronized { + allocatorLock.wait(sleepInterval) + } } catch { case e: InterruptedException => } @@ -349,7 +381,8 @@ private[spark] class ApplicationMaster( t.setDaemon(true) t.setName("Reporter") t.start() - logInfo("Started progress reporter thread - sleep time : " + interval) + logInfo(s"Started progress reporter thread with (heartbeat : $heartbeatInterval, " + + s"initial allocation : $initialAllocationInterval) intervals") t } @@ -465,9 +498,11 @@ private[spark] class ApplicationMaster( new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } + var userArgs = args.userArgs if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - System.setProperty("spark.submit.pyFiles", - PythonRunner.formatPaths(args.pyFiles).mkString(",")) + // When running pyspark, the app is run using PythonRunner. The second argument is the list + // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty. + userArgs = Seq(args.primaryPyFile, "") ++ userArgs } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { // TODO(davies): add R dependencies here @@ -478,9 +513,7 @@ private[spark] class ApplicationMaster( val userThread = new Thread { override def run() { try { - val mainArgs = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) + mainMethod.invoke(null, userArgs.toArray) finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) logDebug("Done running users class") } catch { @@ -524,8 +557,15 @@ private[spark] class ApplicationMaster( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestExecutors(requestedTotal) => Option(allocator) match { - case Some(a) => a.requestTotalExecutors(requestedTotal) - case None => logWarning("Container allocator is not ready to request executors yet.") + case Some(a) => + allocatorLock.synchronized { + if (a.requestTotalExecutors(requestedTotal)) { + allocatorLock.notifyAll() + } + } + + case None => + logWarning("Container allocator is not ready to request executors yet.") } context.reply(true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index ae6dc1094d72..68e9f6b4db7f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -26,11 +26,11 @@ class ApplicationMasterArguments(val args: Array[String]) { var userClass: String = null var primaryPyFile: String = null var primaryRFile: String = null - var pyFiles: String = null - var userArgs: Seq[String] = Seq[String]() + var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 var numExecutors = DEFAULT_NUMBER_EXECUTORS + var propertiesFile: String = null parseArgs(args.toList) @@ -59,10 +59,6 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryRFile = value args = tail - case ("--py-files") :: value :: tail => - pyFiles = value - args = tail - case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -79,6 +75,10 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail + case ("--properties-file") :: value :: tail => + propertiesFile = value + args = tail + case _ => printUsageAndExit(1, args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7e023f2d9257..67a5c95400e5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,18 +17,21 @@ package org.apache.spark.deploy.yarn -import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException, + OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction -import java.util.UUID +import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import scala.util.control.NonFatal +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files @@ -121,24 +124,31 @@ private[spark] class Client( } catch { case e: Throwable => if (appId != null) { - val appStagingDir = getAppStagingDir(appId) - try { - val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) - val stagingDirPath = new Path(appStagingDir) - val fs = FileSystem.get(hadoopConf) - if (!preserveFiles && fs.exists(stagingDirPath)) { - logInfo("Deleting staging directory " + stagingDirPath) - fs.delete(stagingDirPath, true) - } - } catch { - case ioe: IOException => - logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) - } + cleanupStagingDir(appId) } throw e } } + /** + * Cleanup application staging directory. + */ + private def cleanupStagingDir(appId: ApplicationId): Unit = { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } + } + /** * Set up the context for submitting our ApplicationMaster. * This uses the YarnClientApplication not available in the Yarn alpha API. @@ -240,7 +250,9 @@ private[spark] class Client( * This is used for setting up a container launch context for our ApplicationMaster. * Exposed for testing. */ - def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources( + appStagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, LocalResource] = { logInfo("Preparing resources for our AM container") // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. @@ -270,20 +282,6 @@ private[spark] class Client( "for alternatives.") } - // If we passed in a keytab, make sure we copy the keytab to the staging directory on - // HDFS, and setup the relevant environment vars, so the AM can login again. - if (loginFromKeytab) { - logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + - " via the YARN Secure Distributed Cache.") - val localUri = new URI(args.keytab) - val localPath = getQualifiedLocalPath(localUri, hadoopConf) - val destinationPath = copyFileToRemote(dst, localPath, replication) - val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf) - distCacheMgr.addResource( - destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE, - sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true) - } - def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -295,6 +293,57 @@ private[spark] class Client( } } + /** + * Distribute a file to the cluster. + * + * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied + * to HDFS (if not already there) and added to the application's distributed cache. + * + * @param path URI of the file to distribute. + * @param resType Type of resource being distributed. + * @param destName Name of the file in the distributed cache. + * @param targetDir Subdirectory where to place the file. + * @param appMasterOnly Whether to distribute only to the AM. + * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the + * localized path for non-local paths, or the input `path` for local paths. + * The localized path will be null if the URI has already been added to the cache. + */ + def distribute( + path: String, + resType: LocalResourceType = LocalResourceType.FILE, + destName: Option[String] = None, + targetDir: Option[String] = None, + appMasterOnly: Boolean = false): (Boolean, String) = { + val localURI = new URI(path.trim()) + if (localURI.getScheme != LOCAL_SCHEME) { + if (addDistributedUri(localURI)) { + val localPath = getQualifiedLocalPath(localURI, hadoopConf) + val linkname = targetDir.map(_ + "/").getOrElse("") + + destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache, + appMasterOnly = appMasterOnly) + (false, linkname) + } else { + (false, null) + } + } else { + (true, path.trim()) + } + } + + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and setup the relevant environment vars, so the AM can login again. + if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val (_, localizedPath) = distribute(args.keytab, + destName = Some(sparkConf.get("spark.yarn.keytab")), + appMasterOnly = true) + require(localizedPath != null, "Keytab file already distributed.") + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. @@ -307,33 +356,18 @@ private[spark] class Client( (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), ("log4j.properties", oldLog4jConf.orNull, null) - ).foreach { case (destName, _localPath, confKey) => - val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (!localPath.isEmpty()) { - val localURI = new URI(localPath) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) - } - } else if (confKey != null) { + ).foreach { case (destName, path, confKey) => + if (path != null && !path.trim().isEmpty()) { + val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) + if (isLocal && confKey != null) { + require(localizedPath != null, s"Path $path already distributed.") // If the resource is intended for local use only, handle this downstream // by setting the appropriate property - sparkConf.set(confKey, localPath) + sparkConf.set(confKey, localizedPath) } } } - createConfArchive().foreach { file => - require(addDistributedUri(file.toURI())) - val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) - distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, - LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) - } - /** * Do the same for any additional resources passed in through ClientArguments. * Each resource category is represented by a 3-tuple of: @@ -349,21 +383,10 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { flist.split(',').foreach { file => - val localURI = new URI(file.trim()) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname - } - } - } else if (addToClasspath) { - // Resource is intended for local use only and should be added to the class path - cachedSecondaryJarLinks += file.trim() + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } @@ -372,11 +395,31 @@ private[spark] class Client( sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) } + if (isClusterMode && args.primaryPyFile != null) { + distribute(args.primaryPyFile, appMasterOnly = true) + } + + pySparkArchives.foreach { f => distribute(f) } + + // The python files list needs to be treated especially. All files that are not an + // archive need to be placed in a subdirectory that will be added to PYTHONPATH. + args.pyFiles.foreach { f => + val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None + distribute(f, targetDir = targetDir) + } + + // Distribute an archive with Hadoop and Spark configuration for the AM. + val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + resType = LocalResourceType.ARCHIVE, + destName = Some(LOCALIZED_CONF_DIR), + appMasterOnly = true) + require(confLocalizedPath != null) + localResources } /** - * Create an archive with the Hadoop config files for distribution. + * Create an archive with the config files for distribution. * * These are only used by the AM, since executors will use the configuration object broadcast by * the driver. The files are zipped and added to the job as an archive, so that YARN will explode @@ -388,8 +431,11 @@ private[spark] class Client( * * Currently this makes a shallow copy of the conf directory. If there are cases where a * Hadoop config directory contains subdirectories, this code will have to be fixed. + * + * The archive also contains some Spark configuration. Namely, it saves the contents of + * SparkConf in a file to be loaded by the AM process. */ - private def createConfArchive(): Option[File] = { + private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => @@ -404,28 +450,32 @@ private[spark] class Client( } } - if (!hadoopConfFiles.isEmpty) { - val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", - new File(Utils.getLocalDir(sparkConf))) + val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + val confStream = new ZipOutputStream(new FileOutputStream(confArchive)) - val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) - try { - hadoopConfStream.setLevel(0) - hadoopConfFiles.foreach { case (name, file) => - if (file.canRead()) { - hadoopConfStream.putNextEntry(new ZipEntry(name)) - Files.copy(file, hadoopConfStream) - hadoopConfStream.closeEntry() - } + try { + confStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + if (file.canRead()) { + confStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, confStream) + confStream.closeEntry() } - } finally { - hadoopConfStream.close() } - Some(hadoopConfArchive) - } else { - None + // Save Spark configuration to a file in the archive. + val props = new Properties() + sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) + val writer = new OutputStreamWriter(confStream, UTF_8) + props.store(writer, "Spark configuration.") + writer.flush() + confStream.closeEntry() + } finally { + confStream.close() } + confArchive } /** @@ -453,7 +503,9 @@ private[spark] class Client( /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + private def setupLaunchEnv( + stagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") @@ -471,9 +523,6 @@ private[spark] class Client( val renewalInterval = getTokenRenewalInterval(stagingDirPath) sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) } - // Set the environment variables to be passed on to the executors. - distCacheMgr.setDistFilesEnv(env) - distCacheMgr.setDistArchivesEnv(env) // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -490,15 +539,32 @@ private[spark] class Client( env("SPARK_YARN_USER_ENV") = userEnvs } - // if spark.submit.pyArchives is in sparkConf, append pyArchives to PYTHONPATH - // that can be passed on to the ApplicationMaster and the executors. - if (sparkConf.contains("spark.submit.pyArchives")) { - var pythonPath = sparkConf.get("spark.submit.pyArchives") - if (env.contains("PYTHONPATH")) { - pythonPath = Seq(env.get("PYTHONPATH"), pythonPath).mkString(File.pathSeparator) + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH + // of the container processes too. Add all non-.py files directly to PYTHONPATH. + // + // NOTE: the code currently does not handle .py files defined with a "local:" scheme. + val pythonPath = new ListBuffer[String]() + val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + if (pyFiles.nonEmpty) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_PYTHON_DIR) + } + (pySparkArchives ++ pyArchives).foreach { path => + val uri = new URI(path) + if (uri.getScheme != LOCAL_SCHEME) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + new Path(path).getName()) + } else { + pythonPath += uri.getPath() } - env("PYTHONPATH") = pythonPath - sparkConf.setExecutorEnv("PYTHONPATH", pythonPath) + } + + // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. + if (pythonPath.nonEmpty) { + val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + .mkString(YarnSparkHadoopUtil.getClassPathSeparator) + env("PYTHONPATH") = pythonPathStr + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to @@ -548,8 +614,19 @@ private[spark] class Client( logInfo("Setting up container launch context for our AM") val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(appStagingDir) + val pySparkArchives = + if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) { + findPySparkArchives() + } else { + Nil + } + val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives) + val localResources = prepareLocalResources(appStagingDir, pySparkArchives) + + // Set the environment variables to be passed on to the executors. + distCacheMgr.setDistFilesEnv(launchEnv) + distCacheMgr.setDistArchivesEnv(launchEnv) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) amContainer.setEnvironment(launchEnv) @@ -589,13 +666,6 @@ private[spark] class Client( javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // Forward the Spark configuration to the application master / executors. - // TODO: it might be nicer to pass these as an internal environment variable rather than - // as Java options, due to complications with string parsing of nested quotes. - for ((k, v) <- sparkConf.getAll) { - javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") - } - // Include driver-specific java options if we are launching a driver if (isClusterMode) { val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") @@ -606,7 +676,7 @@ private[spark] class Client( val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") @@ -628,7 +698,7 @@ private[spark] class Client( } sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths))) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -648,14 +718,8 @@ private[spark] class Client( Nil } val primaryPyFile = - if (args.primaryPyFile != null) { - Seq("--primary-py-file", args.primaryPyFile) - } else { - Nil - } - val pyFiles = - if (args.pyFiles != null) { - Seq("--py-files", args.pyFiles) + if (isClusterMode && args.primaryPyFile != null) { + Seq("--primary-py-file", new Path(args.primaryPyFile).getName()) } else { Nil } @@ -671,9 +735,6 @@ private[spark] class Client( } else { Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } - if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs - } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs } @@ -681,11 +742,13 @@ private[spark] class Client( Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString) + "--num-executors ", args.numExecutors.toString, + "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) // Command for the ApplicationMaster val commands = prefixEnv ++ Seq( @@ -764,6 +827,9 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + case NonFatal(e) => + logError(s"Failed to contact YARN for application $appId.", e) + return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) } val state = report.getYarnApplicationState @@ -782,6 +848,7 @@ private[spark] class Client( if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { + cleanupStagingDir(appId) return (state, report.getFinalApplicationStatus) } @@ -849,6 +916,22 @@ private[spark] class Client( } } } + + private def findPySparkArchives(): Seq[String] = { + sys.env.get("PYSPARK_ARCHIVES_PATH") + .map(_.split(",").toSeq) + .getOrElse { + val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator) + val pyArchivesFile = new File(pyLibPath, "pyspark.zip") + require(pyArchivesFile.exists(), + "pyspark.zip not found; cannot run pyspark application in YARN mode.") + val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + require(py4jFile.exists(), + "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) + } + } + } object Client extends Logging { @@ -899,8 +982,14 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" - // Subdirectory where the user's hadoop config files will be placed. - val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + // Subdirectory where the user's Spark and Hadoop config files will be placed. + val LOCALIZED_CONF_DIR = "__spark_conf__" + + // Name of the file in the conf archive containing Spark configuration. + val SPARK_CONF_FILE = "__spark_conf__.properties" + + // Subdirectory where the user's python files (not archives) will be placed. + val LOCALIZED_PYTHON_DIR = "__pyfiles__" /** * Find the user-defined Spark jar if configured, or return the jar containing this @@ -1017,15 +1106,15 @@ object Client extends Logging { env: HashMap[String, String], isAM: Boolean, extraClassPath: Option[String] = None): Unit = { - extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env - ) + extraClassPath.foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) if (isAM) { addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + - LOCALIZED_HADOOP_CONF_DIR, env) + LOCALIZED_CONF_DIR, env) } if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { @@ -1036,12 +1125,14 @@ object Client extends Logging { getUserClasspath(sparkConf) } userClassPath.foreach { x => - addFileToClasspath(x, null, env) + addFileToClasspath(sparkConf, x, null, env) } } - addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) - sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) + sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } } /** @@ -1070,16 +1161,18 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * + * @parma conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ private def addFileToClasspath( + conf: SparkConf, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { if (uri != null && uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) + addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) @@ -1093,6 +1186,29 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) + /** + * Returns the path to be sent to the NM for a path that is valid on the gateway. + * + * This method uses two configuration values: + * + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. + * + * If either config is not available, the input path is returned. + */ + def getClusterPath(conf: SparkConf, path: String): String = { + val localPath = conf.get("spark.yarn.config.gatewayPath", null) + val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + if (localPath != null && clusterPath != null) { + path.replace(localPath, clusterPath) + } else { + path + } + } + /** * Obtains token for the Hive metastore and adds them to the credentials. */ @@ -1142,9 +1258,9 @@ object Client extends Logging { logDebug("HiveMetaStore configured in localmode") } } catch { - case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } - case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } - case e:Exception => { logError("Unexpected Exception " + e) + case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e: Exception => { logError("Unexpected Exception " + e) throw new RuntimeException("Unexpected exception", e) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 5653c9f14dc6..19d1bbff9993 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -30,7 +30,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var archives: String = null var userJar: String = null var userClass: String = null - var pyFiles: String = null + var pyFiles: Seq[String] = Nil var primaryPyFile: String = null var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() @@ -46,7 +46,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var keytab: String = null def isClusterMode: Boolean = userClass != null - private var driverMemory: Int = 512 // MB + private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB private var driverCores: Int = 1 private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead" private val amMemKey = "spark.yarn.am.memory" @@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) numExecutors = initialNumExecutors } + principal = Option(principal) + .orElse(sparkConf.getOption("spark.yarn.principal")) + .orNull + keytab = Option(keytab) + .orElse(sparkConf.getOption("spark.yarn.keytab")) + .orNull } /** @@ -222,7 +228,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) args = tail case ("--py-files") :: value :: tail => - pyFiles = value + pyFiles = value.split(",") args = tail case ("--files") :: value :: tail => @@ -256,8 +262,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" + val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB message + - """ + s""" |Usage: org.apache.spark.deploy.yarn.Client [options] |Options: | --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster @@ -269,7 +276,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores per executor (Default: 1). - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb) | --driver-cores NUM Number of cores used by the driver (Default: 1). | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) | --name NAME The name of your application (Default: Spark) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index c592ecfdfce0..3d3a966960e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging { * Add a resource to the list of distributed cache resources. This list can * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. - * Adds the LocalResource to the localResources HashMap passed in and saves + * Adds the LocalResource to the localResources HashMap passed in and saves * the stats of the resources to they can be sent to the executors and verified. * * @param fs FileSystem * @param conf Configuration * @param destPath path to the resource * @param localResources localResource hashMap to insert the resource into - * @param resourceType LocalResourceType + * @param resourceType LocalResourceType * @param link link presented in the distributed cache to the destination - * @param statCache cache to store the file/directory stats + * @param statCache cache to store the file/directory stats * @param appMasterOnly Whether to only add the resource to the app master */ def addResource( fs: FileSystem, conf: Configuration, - destPath: Path, + destPath: Path, localResources: HashMap[String, LocalResource], resourceType: LocalResourceType, link: String, @@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging { amJarRsrc.setSize(destStatus.getLen()) if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { - distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } else { - distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } } @@ -95,13 +95,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -112,13 +112,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -160,7 +160,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def ancestorsHaveExecutePermissions( fs: FileSystem, path: Path, - statCache: Map[URI, FileStatus]): Boolean = { + statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { // the subdirs in the path should have execute permissions for others @@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { val stat = statCache.get(uri) match { case Some(existstat) => existstat - case None => + case None => val newStat = fs.getFileStatus(new Path(uri)) statCache.put(uri, newStat) newStat diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala index 229c2c4d5eb3..94feb6393fd6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -35,6 +35,9 @@ private[spark] class ExecutorDelegationTokenUpdater( @volatile private var lastCredentialsFileSuffix = 0 private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val freshHadoopConf = + SparkHadoopUtil.get.getConfBypassingFSCache( + hadoopConf, new Path(credentialsFile).toUri.getScheme) private val delegationTokenRenewer = Executors.newSingleThreadScheduledExecutor( @@ -49,7 +52,7 @@ private[spark] class ExecutorDelegationTokenUpdater( def updateCredentialsIfRequired(): Unit = { try { val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(hadoopConf) + val remoteFs = FileSystem.get(freshHadoopConf) SparkHadoopUtil.get.listFilesSorted( remoteFs, credentialsFilePath.getParent, credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 9d04d241dae9..78e27fb7f333 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -146,7 +146,7 @@ class ExecutorRunnable( javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } javaOpts += "-Djava.io.tmpdir=" + @@ -195,7 +195,7 @@ class ExecutorRunnable( val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = if (new File(uri.getPath()).isAbsolute()) { - uri.getPath() + Client.getClusterPath(sparkConf, uri.getPath()) } else { Client.buildPath(Environment.PWD.$(), uri.getPath()) } @@ -303,8 +303,8 @@ class ExecutorRunnable( val address = container.getNodeHttpAddress val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user" - env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0" - env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0" + env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=-4096" + env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" } System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 8a08f561a2df..940873fbd046 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -34,10 +34,8 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.AkkaUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -53,6 +51,7 @@ import org.apache.spark.util.AkkaUtils * synchronized. */ private[yarn] class YarnAllocator( + driverUrl: String, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -107,13 +106,6 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - private val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(securityMgr.akkaSSLOptions.enabled), - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) @@ -154,11 +146,16 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. + * + * @return Whether the new requested total is different than the old value. */ - def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { + def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal + true + } else { + false } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b13475136652..7f533ee55e8b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,6 +55,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * @param uiHistoryAddress Address of the application on the History Server. */ def register( + driverUrl: String, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -72,7 +73,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** @@ -89,9 +90,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId = { - val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - containerId.getApplicationAttemptId() + YarnSparkHadoopUtil.get.getContainerId.getApplicationAttemptId() } /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index ba91872107d0..68d01c17ef72 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -33,7 +33,8 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} +import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} +import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -136,12 +137,16 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { tokenRenewer.foreach(_.stop()) } + private[spark] def getContainerId: ContainerId = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + ConverterUtils.toContainerId(containerIdString) + } } object YarnSparkHadoopUtil { - // Additional memory overhead + // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering - // the common cases. Memory overhead tends to grow with container size. + // the common cases. Memory overhead tends to grow with container size. val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 99c05329b4d7..3a0b9443d2d7 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -41,7 +41,6 @@ private[spark] class YarnClientSchedulerBackend( * This waits until the application is running. */ override def start() { - super.start() val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort @@ -56,6 +55,12 @@ private[spark] class YarnClientSchedulerBackend( totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.submitApplication() + + // SPARK-8687: Ensure all necessary properties have already been set before + // we initialize our driver scheduler backend, which serves these properties + // to the executors + super.start() + waitForApplication() monitorThread = asyncMonitorApplication() monitorThread.start() @@ -76,7 +81,8 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue") + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--py-files", null, "spark.submit.pyFiles") ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( @@ -86,7 +92,7 @@ private[spark] class YarnClientSchedulerBackend( optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - } else if (System.getenv(envVar) != null) { + } else if (envVar != null && System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.") @@ -147,7 +153,9 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") - monitorThread.interrupt() + if (monitorThread != null) { + monitorThread.interrupt() + } super.stop() client.stop() logInfo("Stopped") diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index aeb218a57545..33f580aaebdc 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -17,10 +17,19 @@ package org.apache.spark.scheduler.cluster +import java.net.NetworkInterface + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.yarn.api.records.NodeState +import org.apache.hadoop.yarn.client.api.YarnClient +import org.apache.hadoop.yarn.conf.YarnConfiguration + import org.apache.spark.SparkContext +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.IntParam +import org.apache.spark.util.{IntParam, Utils} private[spark] class YarnClusterSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -53,4 +62,71 @@ private[spark] class YarnClusterSchedulerBackend( logError("Application attempt ID is not set.") super.applicationAttemptId } + + override def getDriverLogUrls: Option[Map[String, String]] = { + var yarnClientOpt: Option[YarnClient] = None + var driverLogs: Option[Map[String, String]] = None + try { + val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) + val containerId = YarnSparkHadoopUtil.get.getContainerId + yarnClientOpt = Some(YarnClient.createYarnClient()) + yarnClientOpt.foreach { yarnClient => + yarnClient.init(yarnConf) + yarnClient.start() + + // For newer versions of YARN, we can find the HTTP address for a given node by getting a + // container report for a given container. But container reports came only in Hadoop 2.4, + // so we basically have to get the node reports for all nodes and find the one which runs + // this container. For that we have to compare the node's host against the current host. + // Since the host can have multiple addresses, we need to compare against all of them to + // find out if one matches. + + // Get all the addresses of this node. + val addresses = + NetworkInterface.getNetworkInterfaces.asScala + .flatMap(_.getInetAddresses.asScala) + .toSeq + + // Find a node report that matches one of the addresses + val nodeReport = + yarnClient.getNodeReports(NodeState.RUNNING).asScala.find { x => + val host = x.getNodeId.getHost + addresses.exists { address => + address.getHostAddress == host || + address.getHostName == host || + address.getCanonicalHostName == host + } + } + + // Now that we have found the report for the Node Manager that the AM is running on, we + // can get the base HTTP address for the Node manager from the report. + // The format used for the logs for each container is well-known and can be constructed + // using the NM's HTTP address and the container ID. + // The NM may be running several containers, but we can build the URL for the AM using + // the AM's container ID, which we already know. + nodeReport.foreach { report => + val httpAddress = report.getHttpAddress + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val user = Utils.getCurrentUserName() + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" + logDebug(s"Base URL for logs: $baseUrl") + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) + } + } + } catch { + case e: Exception => + logInfo("Node Report API is not available in the version of YARN being used, so AM" + + " logs link will not appear in application UI", e) + } finally { + yarnClientOpt.foreach(_.close()) + } + driverLogs + } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 80b57d1355a3..804dfecde786 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.mockito.Mockito.when @@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} import scala.collection.mutable.HashMap import scala.collection.mutable.Map +import org.apache.spark.SparkFunSuite -class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + +class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { - override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { LocalResourceVisibility.PRIVATE } } - + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] @@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] val uri = new URI("/tmp/testing") - val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) @@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) // add another one and verify both there and order correct - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", statCache, false) val resource2 = localResources("link2") assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val env2 = new HashMap[String, String]() distMgr.setDistFilesEnv(env2) val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") @@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) intercept[Exception] { - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache, false) } assert(localResources.get("link") === None) @@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, true) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 508819e242a2..837f8d3fa55a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { override def beforeAll(): Unit = { System.setProperty("SPARK_YARN_MODE", "true") @@ -113,7 +113,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { Environment.PWD.$() } cp should contain(pwdVar) - cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -129,7 +129,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { val tempDir = Utils.createTempDir() try { - client.prepareLocalResources(tempDir.getAbsolutePath()) + client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER)) // The non-local path should be propagated by name only, since it will end up in the app's @@ -151,6 +151,25 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { } } + test("Cluster path translation") { + val conf = new Configuration() + val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") + .set("spark.yarn.config.gatewayPath", "/localPath") + .set("spark.yarn.config.replacementPath", "/remotePath") + + Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + "/remotePath/1:/remotePath/2") + + val env = new MutableHashMap[String, String]() + Client.populateClasspath(null, conf, sparkConf, env, false, + extraClassPath = Some("/localPath/my1.jar")) + val cp = classpath(env) + cp should contain ("/remotePath/spark.jar") + cp should contain ("/remotePath/my1.jar") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = @@ -203,7 +222,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 455f1019d86d..7509000771d9 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.spark.SecurityManager +import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} class MockResolver extends DNSToSwitchMapping { @@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping { def reloadCachedMappings(names: JList[String]) {} } -class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new Configuration() conf.setClass( CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, @@ -90,6 +90,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach "--jar", "somejar.jar", "--class", "SomeClass") new YarnAllocator( + "not used", conf, sparkConf, rmClient, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index d3c606e0ed99..335e966519c7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.net.URL import java.util.Properties import java.util.concurrent.TimeUnit @@ -29,11 +30,12 @@ import com.google.common.io.ByteStreams import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} +import org.apache.spark._ import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, + SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -41,7 +43,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { +class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the YARN containers, so that their output is collected // by YARN instead of trying to overwrite unit-tests.log. @@ -54,6 +56,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit """.stripMargin private val TEST_PYFILE = """ + |import mod1, mod2 |import sys |from operator import add | @@ -65,7 +68,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit | sc = SparkContext(conf=SparkConf()) | status = open(sys.argv[1],'w') | result = "failure" - | rdd = sc.parallelize(range(10)) + | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func()) | cnt = rdd.count() | if cnt == 10: | result = "success" @@ -74,6 +77,11 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit | sc.stop() """.stripMargin + private val TEST_PYMODULE = """ + |def func(): + | return 42 + """.stripMargin + private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ @@ -122,7 +130,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) assert(hadoopConfDir.mkdir()) File.createTempFile("token", ".txt", hadoopConfDir) } @@ -149,26 +157,12 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } } - // Enable this once fix SPARK-6700 - test("run Python application in yarn-cluster mode") { - val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, UTF_8) - val pyFile = new File(tempDir, "test2.py") - Files.write(TEST_PYFILE, pyFile, UTF_8) - var result = File.createTempFile("result", null, tempDir) + test("run Python application in yarn-client mode") { + testPySpark(true) + } - // The sbt assembly does not include pyspark / py4j python dependencies, so we need to - // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala. - val sparkHome = sys.props("spark.test.home") - val extraConf = Map( - "spark.executorEnv.SPARK_HOME" -> sparkHome, - "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) - - runSpark(false, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()), - appArgs = Seq(result.getAbsolutePath()), - extraConf = extraConf) - checkResult(result) + test("run Python application in yarn-cluster mode") { + testPySpark(false) } test("user class path first in client mode") { @@ -186,6 +180,33 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit checkResult(result) } + private def testPySpark(clientMode: Boolean): Unit = { + val primaryPyFile = new File(tempDir, "test.py") + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + + val moduleDir = + if (clientMode) { + // In client-mode, .py files added with --py-files are not visible in the driver. + // This is something that the launcher library would have to handle. + tempDir + } else { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } + val pyModule = new File(moduleDir, "mod1.py") + Files.write(TEST_PYMODULE, pyModule, UTF_8) + + val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) + val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") + val result = File.createTempFile("result", null, tempDir) + + runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFiles), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) + } + private def testUseClassPathFirst(clientMode: Boolean): Unit = { // Create a jar file that contains a different version of "test.resource". val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) @@ -290,10 +311,15 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit private[spark] class SaveExecutorInfo extends SparkListener { val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + var driverLogs: Option[collection.Map[String, String]] = None override def onExecutorAdded(executor: SparkListenerExecutorAdded) { addedExecutorInfos(executor.executorId) = executor.executorInfo } + + override def onApplicationStart(appStart: SparkListenerApplicationStart): Unit = { + driverLogs = appStart.driverLogs + } } private object YarnClusterDriver extends Logging with Matchers { @@ -314,11 +340,12 @@ private object YarnClusterDriver extends Logging with Matchers { val sc = new SparkContext(new SparkConf() .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) + val conf = sc.getConf val status = new File(args(0)) var result = "failure" try { val data = sc.parallelize(1 to 4, 4).collect().toSet - assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) data should be (Set(1, 2, 3, 4)) result = "success" } finally { @@ -335,6 +362,22 @@ private object YarnClusterDriver extends Logging with Matchers { executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) } + + // If we are running in yarn-cluster mode, verify that driver logs links and present and are + // in the expected format. + if (conf.get("spark.master") == "yarn-cluster") { + assert(listener.driverLogs.nonEmpty) + val driverLogs = listener.driverLogs.get + assert(driverLogs.size === 2) + assert(driverLogs.containsKey("stderr")) + assert(driverLogs.containsKey("stdout")) + val urlStr = driverLogs("stderr") + // Ensure that this is a valid URL, else this will throw an exception + new URL(urlStr) + val containerId = YarnSparkHadoopUtil.get.getContainerId + val user = Utils.getCurrentUserName() + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) + } } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index e10b985c3c23..49bee0866dd4 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { val hasBash = try {