diff --git a/.gitignore b/.gitignore index 9f8cd0b4cb23..19293932bb17 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ /lib/ R-unit-tests.log R/unit-tests.out +R/cran-check.out build/*.jar build/apache-maven* build/scala* @@ -62,7 +63,7 @@ spark-*-bin-*.tgz spark-tests.log src_managed/ streaming-tests.log -target/ +build-artifacts/ unit-tests.log work/ @@ -77,3 +78,11 @@ spark-warehouse/ # For R session data .RData .RHistory +.Rhistory +*.Rproj +*.Rproj.* + +.Rproj.user + +# gradle specific +.gradle/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f10d7e277eea..1a8206abe383 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,7 @@ It lists steps that are required before creating a PR. In particular, consider: - Is the change important and ready enough to ask the community to spend time reviewing? - Have you searched for existing, related JIRAs and pull requests? -- Is this a new feature that can stand alone as a package on http://spark-packages.org ? +- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ? - Is the change being proposed clearly explained and motivated? When you contribute code, you affirm that the contribution is your original work and that you diff --git a/LICENSE b/LICENSE index 94fd46f56847..d68609cc2873 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.1 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.3 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/R/.gitignore b/R/.gitignore index 9a5889ba28b2..c98504ab0778 100644 --- a/R/.gitignore +++ b/R/.gitignore @@ -4,3 +4,5 @@ lib pkg/man pkg/html +SparkR.Rcheck/ +SparkR_*.tar.gz diff --git a/R/check-cran.sh b/R/check-cran.sh index b3a6860961c1..bb331466ae93 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -43,10 +43,22 @@ $FWDIR/create-docs.sh "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg # Run check as-cran. -# TODO(shivaram): Remove the skip tests once we figure out the install mechanism - VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` -"$R_SCRIPT_PATH/"R CMD check --as-cran --no-tests SparkR_"$VERSION".tar.gz +CRAN_CHECK_OPTIONS="--as-cran" + +if [ -n "$NO_TESTS" ] +then + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-tests" +fi + +if [ -n "$NO_MANUAL" ] +then + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" +fi + +echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" + +"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index d2ae160b5002..69ffc5f678c3 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -17,17 +17,26 @@ # limitations under the License. # -# Script to create API docs for SparkR -# This requires `devtools` and `knitr` to be installed on the machine. +# Script to create API docs and vignettes for SparkR +# This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html +# The vignettes can be found in +# $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html set -o pipefail set -e # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +# Required for setting SPARK_SCALA_VERSION +. "${SPARK_HOME}"/bin/load-spark-env.sh + +echo "Using Scala $SPARK_SCALA_VERSION" + pushd $FWDIR # Install the package (this will also generate the Rd files) @@ -43,4 +52,21 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd +# Find Spark jars. +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +# Only create vignettes if Spark JARs exist +if [ -d "$SPARK_JARS_DIR" ]; then + # render creates SparkR vignettes + Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' + + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" +fi + popd diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index ac73d6c79891..3e49eac99478 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,13 +1,20 @@ Package: SparkR Type: Package Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-07-07 -Author: The Apache Software Foundation -Maintainer: Shivaram Venkataraman +Version: 2.0.1 +Date: 2016-08-27 +Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), + email = "shivaram@cs.berkeley.edu"), + person("Xiangrui", "Meng", role = "aut", + email = "meng@databricks.com"), + person("Felix", "Cheung", role = "aut", + email = "felixcheung@apache.org"), + person(family = "The Apache Software Foundation", role = c("aut", "cph"))) +URL: http://www.apache.org/ http://spark.apache.org/ +BugReports: https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingBugReports Depends: R (>= 3.0), - methods, + methods Suggests: testthat, e1071, @@ -31,6 +38,8 @@ Collate: 'context.R' 'deserialize.R' 'functions.R' + 'install.R' + 'jvm.R' 'mllib.R' 'serialize.R' 'sparkR.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 1d74c6d95578..4c77d951247f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,5 +1,9 @@ # Imports from base R -importFrom(methods, setGeneric, setMethod, setOldClass) +# Do not include stats:: "rpois", "runif" - causes error at runtime +importFrom("methods", "setGeneric", "setMethod", "setOldClass") +importFrom("methods", "is", "new", "signature", "show") +importFrom("stats", "gaussian", "setNames") +importFrom("utils", "download.file", "packageVersion", "untar") # Disable native libraries till we figure out how to package it # See SPARKR-7839 @@ -11,8 +15,15 @@ export("sparkR.init") export("sparkR.stop") export("sparkR.session.stop") export("sparkR.conf") +export("sparkR.version") export("print.jobj") +export("sparkR.newJObject") +export("sparkR.callJMethod") +export("sparkR.callJStatic") + +export("install.spark") + export("sparkRSQL.init", "sparkRHive.init") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index aa211b326a16..a5bd60337601 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -120,8 +120,9 @@ setMethod("schema", #' #' Print the logical and physical Catalyst plans to the console for debugging. #' -#' @param x A SparkDataFrame +#' @param x a SparkDataFrame. #' @param extended Logical. If extended is FALSE, explain() only prints the physical plan. +#' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases explain,SparkDataFrame-method #' @rdname explain @@ -149,7 +150,7 @@ setMethod("explain", #' isLocal #' -#' Returns True if the `collect` and `take` methods can be run locally +#' Returns True if the \code{collect} and \code{take} methods can be run locally #' (without any Spark executors). #' #' @param x A SparkDataFrame @@ -177,11 +178,11 @@ setMethod("isLocal", #' #' Print the first numRows rows of a SparkDataFrame #' -#' @param x A SparkDataFrame -#' @param numRows The number of rows to print. Defaults to 20. -#' @param truncate Whether truncate long strings. If true, strings more than 20 characters will be -#' truncated and all cells will be aligned right -#' +#' @param x a SparkDataFrame. +#' @param numRows the number of rows to print. Defaults to 20. +#' @param truncate whether truncate long strings. If \code{TRUE}, strings more than +#' 20 characters will be truncated and all cells will be aligned right. +#' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases showDF,SparkDataFrame-method #' @rdname showDF @@ -204,9 +205,9 @@ setMethod("showDF", #' show #' -#' Print the SparkDataFrame column names and types +#' Print class and type information of a Spark object. #' -#' @param x A SparkDataFrame +#' @param object a Spark object. Can be a SparkDataFrame, Column, GroupedData, WindowSpec. #' #' @family SparkDataFrame functions #' @rdname show @@ -218,7 +219,7 @@ setMethod("showDF", #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) -#' df +#' show(df) #'} #' @note show(SparkDataFrame) since 1.4.0 setMethod("show", "SparkDataFrame", @@ -257,11 +258,11 @@ setMethod("dtypes", }) }) -#' Column names +#' Column Names of SparkDataFrame #' -#' Return all column names as a list +#' Return all column names as a list. #' -#' @param x A SparkDataFrame +#' @param x a SparkDataFrame. #' #' @family SparkDataFrame functions #' @rdname columns @@ -318,6 +319,8 @@ setMethod("colnames", columns(x) }) +#' @param value a character vector. Must have the same length as the number +#' of columns in the SparkDataFrame. #' @rdname columns #' @aliases colnames<-,SparkDataFrame-method #' @name colnames<- @@ -363,7 +366,7 @@ setMethod("colnames<-", #' @examples #'\dontrun{ #' irisDF <- createDataFrame(iris) -#' coltypes(irisDF) +#' coltypes(irisDF) # get column types #'} #' @note coltypes since 1.6.0 setMethod("coltypes", @@ -387,7 +390,11 @@ setMethod("coltypes", } if (is.null(type)) { - stop(paste("Unsupported data type: ", x)) + specialtype <- specialtypeshandle(x) + if (is.null(specialtype)) { + stop(paste("Unsupported data type: ", x)) + } + type <- PRIMITIVE_TYPES[[specialtype]] } } type @@ -406,7 +413,6 @@ setMethod("coltypes", #' #' Set the column types of a SparkDataFrame. #' -#' @param x A SparkDataFrame #' @param value A character vector with the target column types for the given #' SparkDataFrame. Column types can be one of integer, numeric/double, character, logical, or NA #' to keep that column as-is. @@ -419,8 +425,8 @@ setMethod("coltypes", #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) -#' coltypes(df) <- c("character", "integer") -#' coltypes(df) <- c(NA, "numeric") +#' coltypes(df) <- c("character", "integer") # set column types +#' coltypes(df) <- c(NA, "numeric") # set column types #'} #' @note coltypes<- since 1.6.0 setMethod("coltypes<-", @@ -510,9 +516,10 @@ setMethod("registerTempTable", #' #' Insert the contents of a SparkDataFrame into a table registered in the current SparkSession. #' -#' @param x A SparkDataFrame -#' @param tableName A character vector containing the name of the table -#' @param overwrite A logical argument indicating whether or not to overwrite +#' @param x a SparkDataFrame. +#' @param tableName a character vector containing the name of the table. +#' @param overwrite a logical argument indicating whether or not to overwrite. +#' @param ... further arguments to be passed to or from other methods. #' the existing rows in the table. #' #' @family SparkDataFrame functions @@ -571,7 +578,9 @@ setMethod("cache", #' supported storage levels, refer to #' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. #' -#' @param x The SparkDataFrame to persist +#' @param x the SparkDataFrame to persist. +#' @param newLevel storage level chosen for the persistance. See available options in +#' the description. #' #' @family SparkDataFrame functions #' @rdname persist @@ -599,8 +608,9 @@ setMethod("persist", #' Mark this SparkDataFrame as non-persistent, and remove all blocks for it from memory and #' disk. #' -#' @param x The SparkDataFrame to unpersist -#' @param blocking Whether to block until all blocks are deleted +#' @param x the SparkDataFrame to unpersist. +#' @param blocking whether to block until all blocks are deleted. +#' @param ... further arguments to be passed to or from other methods. #' #' @family SparkDataFrame functions #' @rdname unpersist-methods @@ -629,14 +639,15 @@ setMethod("unpersist", #' The following options for repartition are possible: #' \itemize{ #' \item{1.} {Return a new SparkDataFrame partitioned by -#' the given columns into `numPartitions`.} -#' \item{2.} {Return a new SparkDataFrame that has exactly `numPartitions`.} +#' the given columns into \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame that has exactly \code{numPartitions}.} #' \item{3.} {Return a new SparkDataFrame partitioned by the given column(s), -#' using `spark.sql.shuffle.partitions` as number of partitions.} +#' using \code{spark.sql.shuffle.partitions} as number of partitions.} #'} -#' @param x A SparkDataFrame -#' @param numPartitions The number of partitions to use. -#' @param col The column by which the partitioning will be performed. +#' @param x a SparkDataFrame. +#' @param numPartitions the number of partitions to use. +#' @param col the column by which the partitioning will be performed. +#' @param ... additional column(s) to be used in the partitioning. #' #' @family SparkDataFrame functions #' @rdname repartition @@ -915,11 +926,10 @@ setMethod("sample_frac", #' Returns the number of rows in a SparkDataFrame #' -#' @param x A SparkDataFrame -#' +#' @param x a SparkDataFrame. #' @family SparkDataFrame functions #' @rdname nrow -#' @name count +#' @name nrow #' @aliases count,SparkDataFrame-method #' @export #' @examples @@ -995,9 +1005,10 @@ setMethod("dim", #' Collects all the elements of a SparkDataFrame and coerces them into an R data.frame. #' -#' @param x A SparkDataFrame -#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' @param x a SparkDataFrame. +#' @param stringsAsFactors (Optional) a logical indicating whether or not string columns #' should be converted to factors. FALSE by default. +#' @param ... further arguments to be passed to or from other methods. #' #' @family SparkDataFrame functions #' @rdname collect @@ -1049,6 +1060,13 @@ setMethod("collect", df[[colIndex]] <- col } else { colType <- dtypes[[colIndex]][[2]] + if (is.null(PRIMITIVE_TYPES[[colType]])) { + specialtype <- specialtypeshandle(colType) + if (!is.null(specialtype)) { + colType <- specialtype + } + } + # Note that "binary" columns behave like complex types. if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) @@ -1092,8 +1110,10 @@ setMethod("limit", dataFrame(res) }) -#' Take the first NUM rows of a SparkDataFrame and return a the results as a R data.frame +#' Take the first NUM rows of a SparkDataFrame and return the results as a R data.frame #' +#' @param x a SparkDataFrame. +#' @param num number of rows to take. #' @family SparkDataFrame functions #' @rdname take #' @name take @@ -1116,13 +1136,12 @@ setMethod("take", #' Head #' -#' Return the first NUM rows of a SparkDataFrame as a R data.frame. If NUM is NULL, -#' then head() returns the first 6 rows in keeping with the current data.frame -#' convention in R. +#' Return the first \code{num} rows of a SparkDataFrame as a R data.frame. If \code{num} is not +#' specified, then head() returns the first 6 rows as with R data.frame. #' -#' @param x A SparkDataFrame -#' @param num The number of rows to return. Default is 6. -#' @return A data.frame +#' @param x a SparkDataFrame. +#' @param num the number of rows to return. Default is 6. +#' @return A data.frame. #' #' @family SparkDataFrame functions #' @aliases head,SparkDataFrame-method @@ -1146,7 +1165,8 @@ setMethod("head", #' Return the first row of a SparkDataFrame #' -#' @param x A SparkDataFrame +#' @param x a SparkDataFrame or a column used in aggregation function. +#' @param ... further arguments to be passed to or from other methods. #' #' @family SparkDataFrame functions #' @aliases first,SparkDataFrame-method @@ -1197,8 +1217,9 @@ setMethod("toRDD", #' #' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them. #' -#' @param x a SparkDataFrame -#' @return a GroupedData +#' @param x a SparkDataFrame. +#' @param ... variable(s) (character names(s) or Column(s)) to group on. +#' @return A GroupedData. #' @family SparkDataFrame functions #' @aliases groupBy,SparkDataFrame-method #' @rdname groupBy @@ -1240,7 +1261,6 @@ setMethod("group_by", #' #' Compute aggregates by specifying a list of columns #' -#' @param x a SparkDataFrame #' @family SparkDataFrame functions #' @aliases agg,SparkDataFrame-method #' @rdname summarize @@ -1387,16 +1407,15 @@ setMethod("dapplyCollect", #' Groups the SparkDataFrame using the specified columns and applies the R function to each #' group. #' -#' @param x A SparkDataFrame -#' @param cols Grouping columns -#' @param func A function to be applied to each group partition specified by grouping -#' column of the SparkDataFrame. The function `func` takes as argument +#' @param cols grouping columns. +#' @param func a function to be applied to each group partition specified by grouping +#' column of the SparkDataFrame. The function \code{func} takes as argument #' a key - grouping columns and a data frame - a local R data.frame. -#' The output of `func` is a local R data.frame. -#' @param schema The schema of the resulting SparkDataFrame after the function is applied. -#' The schema must match to output of `func`. It has to be defined for each +#' The output of \code{func} is a local R data.frame. +#' @param schema the schema of the resulting SparkDataFrame after the function is applied. +#' The schema must match to output of \code{func}. It has to be defined for each #' output column with preferred output column name and corresponding data type. -#' @return a SparkDataFrame +#' @return A SparkDataFrame. #' @family SparkDataFrame functions #' @aliases gapply,SparkDataFrame-method #' @rdname gapply @@ -1479,13 +1498,12 @@ setMethod("gapply", #' Groups the SparkDataFrame using the specified columns, applies the R function to each #' group and collects the result back to R as data.frame. #' -#' @param x A SparkDataFrame -#' @param cols Grouping columns -#' @param func A function to be applied to each group partition specified by grouping -#' column of the SparkDataFrame. The function `func` takes as argument +#' @param cols grouping columns. +#' @param func a function to be applied to each group partition specified by grouping +#' column of the SparkDataFrame. The function \code{func} takes as argument #' a key - grouping columns and a data frame - a local R data.frame. -#' The output of `func` is a local R data.frame. -#' @return a data.frame +#' The output of \code{func} is a local R data.frame. +#' @return A data.frame. #' @family SparkDataFrame functions #' @aliases gapplyCollect,SparkDataFrame-method #' @rdname gapplyCollect @@ -1632,6 +1650,7 @@ getColumn <- function(x, c) { column(callJMethod(x@sdf, "col", c)) } +#' @param name name of a Column (without being wrapped by \code{""}). #' @rdname select #' @name $ #' @aliases $,SparkDataFrame-method @@ -1641,6 +1660,7 @@ setMethod("$", signature(x = "SparkDataFrame"), getColumn(x, name) }) +#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped. #' @rdname select #' @name $<- #' @aliases $<-,SparkDataFrame-method @@ -1715,12 +1735,13 @@ setMethod("[", signature(x = "SparkDataFrame"), #' Subset #' #' Return subsets of SparkDataFrame according to given conditions -#' @param x A SparkDataFrame -#' @param subset (Optional) A logical expression to filter on rows -#' @param select expression for the single Column or a list of columns to select from the SparkDataFrame +#' @param x a SparkDataFrame. +#' @param i,subset (Optional) a logical expression to filter on rows. +#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame. #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column. -#' Otherwise, a SparkDataFrame will always be returned. -#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns +#' Otherwise, a SparkDataFrame will always be returned. +#' @param ... currently not used. +#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns. #' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method @@ -1729,7 +1750,7 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @family subsetting functions #' @examples #' \dontrun{ -#' # Columns can be selected using `[[` and `[` +#' # Columns can be selected using [[ and [ #' df[[2]] == df[["age"]] #' df[,2] == df[,"age"] #' df[,c("name", "age")] @@ -1755,9 +1776,12 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' Select #' #' Selects a set of columns with names or Column expressions. -#' @param x A SparkDataFrame -#' @param col A list of columns or single Column or name -#' @return A new SparkDataFrame with selected columns +#' @param x a SparkDataFrame. +#' @param col a list of columns or single Column or name. +#' @param ... additional column(s) if only one column is specified in \code{col}. +#' If more than one column is assigned in \code{col}, \code{...} +#' should be left empty. +#' @return A new SparkDataFrame with selected columns. #' @export #' @family SparkDataFrame functions #' @rdname select @@ -1771,7 +1795,7 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' select(df, df$name, df$age + 1) #' select(df, c("col1", "col2")) #' select(df, list(df$name, df$age + 1)) -#' # Similar to R data frames columns can also be selected using `$` +#' # Similar to R data frames columns can also be selected using $ #' df[,df$age] #' } #' @note select(SparkDataFrame, character) since 1.4.0 @@ -1854,9 +1878,9 @@ setMethod("selectExpr", #' Return a new SparkDataFrame by adding a column or replacing the existing column #' that has the same name. #' -#' @param x A SparkDataFrame -#' @param colName A column name. -#' @param col A Column expression. +#' @param x a SparkDataFrame. +#' @param colName a column name. +#' @param col a Column expression. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character,Column-method @@ -1885,8 +1909,8 @@ setMethod("withColumn", #' #' Return a new SparkDataFrame with the specified columns added or replaced. #' -#' @param .data A SparkDataFrame -#' @param col a named argument of the form name = col +#' @param .data a SparkDataFrame. +#' @param ... additional column argument(s) each in the form name = col. #' @return A new SparkDataFrame with the new columns added or replaced. #' @family SparkDataFrame functions #' @aliases mutate,SparkDataFrame-method @@ -1963,6 +1987,7 @@ setMethod("mutate", do.call(select, c(x, colList, deDupCols)) }) +#' @param _data a SparkDataFrame. #' @export #' @rdname mutate #' @aliases transform,SparkDataFrame-method @@ -2044,14 +2069,14 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) -#' Arrange +#' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). #' -#' @param x A SparkDataFrame to be sorted. -#' @param col A character or Column object vector indicating the fields to sort on -#' @param ... Additional sorting fields -#' @param decreasing A logical argument indicating sorting order for columns when +#' @param x a SparkDataFrame to be sorted. +#' @param col a character or Column object indicating the fields to sort on +#' @param ... additional sorting fields +#' @param decreasing a logical argument indicating sorting order for columns when #' a character vector is specified for col #' @return A SparkDataFrame where all elements are sorted. #' @family SparkDataFrame functions @@ -2116,7 +2141,6 @@ setMethod("arrange", }) #' @rdname arrange -#' @name orderBy #' @aliases orderBy,SparkDataFrame,characterOrColumn-method #' @export #' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0 @@ -2275,11 +2299,18 @@ setMethod("join", #' specified, the common column names in \code{x} and \code{y} will be used. #' @param by.x a character vector specifying the joining columns for x. #' @param by.y a character vector specifying the joining columns for y. +#' @param all a boolean value setting \code{all.x} and \code{all.y} +#' if any of them are unset. #' @param all.x a boolean value indicating whether all the rows in x should #' be including in the join #' @param all.y a boolean value indicating whether all the rows in y should #' be including in the join #' @param sort a logical argument indicating whether the resulting columns should be sorted +#' @param suffixes a string vector of length 2 used to make colnames of +#' \code{x} and \code{y} unique. +#' The first element is appended to each colname of \code{x}. +#' The second element is appended to each colname of \code{y}. +#' @param ... additional argument(s) passed to the method. #' @details If all.x and all.y are set to FALSE, a natural join will be returned. If #' all.x is set to TRUE and all.y is set to FALSE, a left outer join will #' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right @@ -2308,7 +2339,7 @@ setMethod("merge", signature(x = "SparkDataFrame", y = "SparkDataFrame"), function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by, all = FALSE, all.x = all, all.y = all, - sort = TRUE, suffixes = c("_x", "_y"), ... ) { + sort = TRUE, suffixes = c("_x", "_y"), ...) { if (length(suffixes) != 2) { stop("suffixes must have length 2") @@ -2415,7 +2446,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' Return a new SparkDataFrame containing the union of rows #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame -#' and another SparkDataFrame. This is equivalent to `UNION ALL` in SQL. +#' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. #' Note that this does not remove duplicate rows across the two SparkDataFrames. #' #' @param x A SparkDataFrame @@ -2458,11 +2489,13 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' -#' Union two or more SparkDataFrames. This is equivalent to `UNION ALL` in SQL. +#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. #' Note that this does not remove duplicate rows across the two SparkDataFrames. #' -#' @param x A SparkDataFrame -#' @param ... Additional SparkDataFrame +#' @param x a SparkDataFrame. +#' @param ... additional SparkDataFrame(s). +#' @param deparse.level currently not used (put here to match the signature of +#' the base implementation). #' @return A SparkDataFrame containing the result of the union. #' @family SparkDataFrame functions #' @aliases rbind,SparkDataFrame-method @@ -2489,7 +2522,7 @@ setMethod("rbind", #' Intersect #' #' Return a new SparkDataFrame containing rows only in both this SparkDataFrame -#' and another SparkDataFrame. This is equivalent to `INTERSECT` in SQL. +#' and another SparkDataFrame. This is equivalent to \code{INTERSECT} in SQL. #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame @@ -2517,10 +2550,10 @@ setMethod("intersect", #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame -#' but not in another SparkDataFrame. This is equivalent to `EXCEPT` in SQL. +#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT} in SQL. #' -#' @param x A SparkDataFrame -#' @param y A SparkDataFrame +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. #' @return A SparkDataFrame containing the result of the except operation. #' @family SparkDataFrame functions #' @aliases except,SparkDataFrame,SparkDataFrame-method @@ -2546,8 +2579,8 @@ setMethod("except", #' Save the contents of SparkDataFrame to a data source. #' -#' The data source is specified by the `source` and a set of options (...). -#' If `source` is not specified, the default data source configured by +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when data already @@ -2561,10 +2594,11 @@ setMethod("except", #' and to not change the existing data. #' } #' -#' @param df A SparkDataFrame -#' @param path A name for the table -#' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param df a SparkDataFrame. +#' @param path a name for the table. +#' @param source a name for external data source. +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. #' #' @family SparkDataFrame functions #' @aliases write.df,SparkDataFrame,character-method @@ -2582,7 +2616,7 @@ setMethod("except", #' @note write.df since 1.4.0 setMethod("write.df", signature(df = "SparkDataFrame", path = "character"), - function(df, path, source = NULL, mode = "error", ...){ + function(df, path, source = NULL, mode = "error", ...) { if (is.null(source)) { source <- getDefaultSqlSource() } @@ -2594,6 +2628,7 @@ setMethod("write.df", write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "options", options) write <- callJMethod(write, "save", path) }) @@ -2604,14 +2639,14 @@ setMethod("write.df", #' @note saveDF since 1.4.0 setMethod("saveDF", signature(df = "SparkDataFrame", path = "character"), - function(df, path, source = NULL, mode = "error", ...){ + function(df, path, source = NULL, mode = "error", ...) { write.df(df, path, source, mode, ...) }) #' Save the contents of the SparkDataFrame to a data source as a table #' -#' The data source is specified by the `source` and a set of options (...). -#' If `source` is not specified, the default data source configured by +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when @@ -2623,10 +2658,11 @@ setMethod("saveDF", #' ignore: The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. \cr #' -#' @param df A SparkDataFrame -#' @param tableName A name for the table -#' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param df a SparkDataFrame. +#' @param tableName a name for the table. +#' @param source a name for external data source. +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param ... additional option(s) passed to the method. #' #' @family SparkDataFrame functions #' @aliases saveAsTable,SparkDataFrame,character-method @@ -2643,7 +2679,7 @@ setMethod("saveDF", #' @note saveAsTable since 1.4.0 setMethod("saveAsTable", signature(df = "SparkDataFrame", tableName = "character"), - function(df, tableName, source = NULL, mode="error", ...){ + function(df, tableName, source = NULL, mode="error", ...) { if (is.null(source)) { source <- getDefaultSqlSource() } @@ -2662,10 +2698,10 @@ setMethod("saveAsTable", #' Computes statistics for numeric columns. #' If no columns are given, this function computes statistics for all numerical columns. #' -#' @param x A SparkDataFrame to be computed. -#' @param col A string of name -#' @param ... Additional expressions -#' @return A SparkDataFrame +#' @param x a SparkDataFrame to be computed. +#' @param col a string of name. +#' @param ... additional expressions. +#' @return A SparkDataFrame. #' @family SparkDataFrame functions #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method #' @rdname summary @@ -2700,6 +2736,7 @@ setMethod("describe", dataFrame(sdf) }) +#' @param object a SparkDataFrame to be summarized. #' @rdname summary #' @name summary #' @aliases summary,SparkDataFrame-method @@ -2715,16 +2752,20 @@ setMethod("summary", #' #' dropna, na.omit - Returns a new SparkDataFrame omitting rows with null values. #' -#' @param x A SparkDataFrame. +#' @param x a SparkDataFrame. #' @param how "any" or "all". #' if "any", drop a row if it contains any nulls. #' if "all", drop a row only if all its values are null. -#' if minNonNulls is specified, how is ignored. -#' @param minNonNulls If specified, drop rows that have less than -#' minNonNulls non-null values. +#' if \code{minNonNulls} is specified, how is ignored. +#' @param minNonNulls if specified, drop rows that have less than +#' \code{minNonNulls} non-null values. #' This overwrites the how parameter. -#' @param cols Optional list of column names to consider. -#' @return A SparkDataFrame +#' @param cols optional list of column names to consider. In \code{fillna}, +#' columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A SparkDataFrame. #' #' @family SparkDataFrame functions #' @rdname nafunctions @@ -2756,6 +2797,8 @@ setMethod("dropna", dataFrame(sdf) }) +#' @param object a SparkDataFrame. +#' @param ... further arguments to be passed to or from other methods. #' @rdname nafunctions #' @name na.omit #' @aliases na.omit,SparkDataFrame-method @@ -2769,18 +2812,12 @@ setMethod("na.omit", #' fillna - Replace null values. #' -#' @param x A SparkDataFrame. -#' @param value Value to replace null values with. +#' @param value value to replace null values with. #' Should be an integer, numeric, character or named list. #' If the value is a named list, then cols is ignored and #' value must be a mapping from column name (character) to #' replacement value. The replacement value must be an #' integer, numeric or character. -#' @param cols optional list of column names to consider. -#' Columns specified in cols that do not have matching data -#' type are ignored. For example, if value is a character, and -#' subset contains a non-character column, then the non-character -#' column is simply ignored. #' #' @rdname nafunctions #' @name fillna @@ -2845,8 +2882,11 @@ setMethod("fillna", #' Since data.frames are held in memory, ensure that you have enough memory #' in your system to accommodate the contents. #' -#' @param x a SparkDataFrame -#' @return a data.frame +#' @param x a SparkDataFrame. +#' @param row.names \code{NULL} or a character vector giving the row names for the data frame. +#' @param optional If \code{TRUE}, converting column names is optional. +#' @param ... additional arguments to pass to base::as.data.frame. +#' @return A data.frame. #' @family SparkDataFrame functions #' @aliases as.data.frame,SparkDataFrame-method #' @rdname as.data.frame @@ -3000,9 +3040,10 @@ setMethod("str", #' Returns a new SparkDataFrame with columns dropped. #' This is a no-op if schema doesn't contain column name(s). #' -#' @param x A SparkDataFrame. -#' @param cols A character vector of column names or a Column. -#' @return A SparkDataFrame +#' @param x a SparkDataFrame. +#' @param col a character vector of column names or a Column. +#' @param ... further arguments to be passed to or from other methods. +#' @return A SparkDataFrame. #' #' @family SparkDataFrame functions #' @rdname drop @@ -3049,8 +3090,8 @@ setMethod("drop", #' #' @name histogram #' @param nbins the number of bins (optional). Default value is 10. +#' @param col the column as Character string or a Column to build the histogram from. #' @param df the SparkDataFrame containing the Column to build the histogram from. -#' @param colname the name of the column to build the histogram from. #' @return a data.frame with the histogram statistics, i.e., counts and centroids. #' @rdname histogram #' @aliases histogram,SparkDataFrame,characterOrColumn-method @@ -3181,10 +3222,11 @@ setMethod("histogram", #' and to not change the existing data. #' } #' -#' @param x A SparkDataFrame -#' @param url JDBC database url of the form `jdbc:subprotocol:subname` -#' @param tableName The name of the table in the external database -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param x a SparkDataFrame. +#' @param url JDBC database url of the form \code{jdbc:subprotocol:subname}. +#' @param tableName yhe name of the table in the external database. +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param ... additional JDBC database connection properties. #' @family SparkDataFrame functions #' @rdname write.jdbc #' @name write.jdbc @@ -3199,7 +3241,7 @@ setMethod("histogram", #' @note write.jdbc since 2.0.0 setMethod("write.jdbc", signature(x = "SparkDataFrame", url = "character", tableName = "character"), - function(x, url, tableName, mode = "error", ...){ + function(x, url, tableName, mode = "error", ...) { jmode <- convertToJSaveMode(mode) jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 72a805256523..6cd0704003f1 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -67,7 +67,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, .Object }) -setMethod("show", "RDD", +setMethod("showRDD", "RDD", function(object) { cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep = "")) }) @@ -215,7 +215,7 @@ setValidity("RDD", #' @rdname cache-methods #' @aliases cache,RDD-method #' @noRd -setMethod("cache", +setMethod("cacheRDD", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "cache") @@ -235,12 +235,12 @@ setMethod("cache", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) -#' persist(rdd, "MEMORY_AND_DISK") +#' persistRDD(rdd, "MEMORY_AND_DISK") #'} #' @rdname persist #' @aliases persist,RDD-method #' @noRd -setMethod("persist", +setMethod("persistRDD", signature(x = "RDD", newLevel = "character"), function(x, newLevel = "MEMORY_ONLY") { callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) @@ -259,12 +259,12 @@ setMethod("persist", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) #' cache(rdd) # rdd@@env$isCached == TRUE -#' unpersist(rdd) # rdd@@env$isCached == FALSE +#' unpersistRDD(rdd) # rdd@@env$isCached == FALSE #'} #' @rdname unpersist-methods #' @aliases unpersist,RDD-method #' @noRd -setMethod("unpersist", +setMethod("unpersistRDD", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "unpersist") @@ -345,13 +345,13 @@ setMethod("numPartitions", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) -#' collect(rdd) # list from 1 to 10 +#' collectRDD(rdd) # list from 1 to 10 #' collectPartition(rdd, 0L) # list from 1 to 5 #'} #' @rdname collect-methods #' @aliases collect,RDD-method #' @noRd -setMethod("collect", +setMethod("collectRDD", signature(x = "RDD"), function(x, flatten = TRUE) { # Assumes a pairwise RDD is backed by a JavaPairRDD. @@ -397,7 +397,7 @@ setMethod("collectPartition", setMethod("collectAsMap", signature(x = "RDD"), function(x) { - pairList <- collect(x) + pairList <- collectRDD(x) map <- new.env() lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) as.list(map) @@ -411,30 +411,30 @@ setMethod("collectAsMap", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' count(rdd) # 10 +#' countRDD(rdd) # 10 #' length(rdd) # Same as count #'} #' @rdname count #' @aliases count,RDD-method #' @noRd -setMethod("count", +setMethod("countRDD", signature(x = "RDD"), function(x) { countPartition <- function(part) { as.integer(length(part)) } valsRDD <- lapplyPartition(x, countPartition) - vals <- collect(valsRDD) + vals <- collectRDD(valsRDD) sum(as.integer(vals)) }) #' Return the number of elements in the RDD #' @rdname count #' @noRd -setMethod("length", +setMethod("lengthRDD", signature(x = "RDD"), function(x) { - count(x) + countRDD(x) }) #' Return the count of each unique value in this RDD as a list of @@ -460,7 +460,7 @@ setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collect(reduceByKey(ones, `+`, getNumPartitions(x))) + collectRDD(reduceByKey(ones, `+`, getNumPartitions(x))) }) #' Apply a function to all elements @@ -479,7 +479,7 @@ setMethod("countByValue", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) -#' collect(multiplyByTwo) # 2,4,6... +#' collectRDD(multiplyByTwo) # 2,4,6... #'} setMethod("lapply", signature(X = "RDD", FUN = "function"), @@ -512,7 +512,7 @@ setMethod("map", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) -#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#' collectRDD(multiplyByTwo) # 2,20,4,40,6,60... #'} #' @rdname flatMap #' @aliases flatMap,RDD,function-method @@ -541,7 +541,7 @@ setMethod("flatMap", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) -#' collect(partitionSum) # 15, 40 +#' collectRDD(partitionSum) # 15, 40 #'} #' @rdname lapplyPartition #' @aliases lapplyPartition,RDD,function-method @@ -576,7 +576,7 @@ setMethod("mapPartitions", #' rdd <- parallelize(sc, 1:10, 5L) #' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { #' partIndex * Reduce("+", part) }) -#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#' collectRDD(prod, flatten = FALSE) # 0, 7, 22, 45, 76 #'} #' @rdname lapplyPartitionsWithIndex #' @aliases lapplyPartitionsWithIndex,RDD,function-method @@ -607,7 +607,7 @@ setMethod("mapPartitionsWithIndex", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#' unlist(collectRDD(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) #'} # nolint end #' @rdname filterRDD @@ -656,7 +656,7 @@ setMethod("reduce", Reduce(func, part) } - partitionList <- collect(lapplyPartition(x, reducePartition), + partitionList <- collectRDD(lapplyPartition(x, reducePartition), flatten = FALSE) Reduce(func, partitionList) }) @@ -736,7 +736,7 @@ setMethod("foreach", lapply(x, func) NULL } - invisible(collect(mapPartitions(x, partition.func))) + invisible(collectRDD(mapPartitions(x, partition.func))) }) #' Applies a function to each partition in an RDD, and forces evaluation. @@ -753,7 +753,7 @@ setMethod("foreach", setMethod("foreachPartition", signature(x = "RDD", func = "function"), function(x, func) { - invisible(collect(mapPartitions(x, func))) + invisible(collectRDD(mapPartitions(x, func))) }) #' Take elements from an RDD. @@ -768,13 +768,13 @@ setMethod("foreachPartition", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' take(rdd, 2L) # list(1, 2) +#' takeRDD(rdd, 2L) # list(1, 2) #'} # nolint end #' @rdname take #' @aliases take,RDD,numeric-method #' @noRd -setMethod("take", +setMethod("takeRDD", signature(x = "RDD", num = "numeric"), function(x, num) { resList <- list() @@ -817,13 +817,13 @@ setMethod("take", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' first(rdd) +#' firstRDD(rdd) #' } #' @noRd -setMethod("first", +setMethod("firstRDD", signature(x = "RDD"), function(x) { - take(x, 1)[[1]] + takeRDD(x, 1)[[1]] }) #' Removes the duplicates from RDD. @@ -838,13 +838,13 @@ setMethod("first", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,2,3,3,3)) -#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#' sort(unlist(collectRDD(distinctRDD(rdd)))) # c(1, 2, 3) #'} # nolint end #' @rdname distinct #' @aliases distinct,RDD-method #' @noRd -setMethod("distinct", +setMethod("distinctRDD", signature(x = "RDD"), function(x, numPartitions = SparkR:::getNumPartitions(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) @@ -868,8 +868,8 @@ setMethod("distinct", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements -#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#' collectRDD(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collectRDD(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates #'} #' @rdname sampleRDD #' @aliases sampleRDD,RDD @@ -887,17 +887,17 @@ setMethod("sampleRDD", # Discards some random values to ensure each partition has a # different random seed. - runif(partIndex) + stats::runif(partIndex) for (elem in part) { if (withReplacement) { - count <- rpois(1, fraction) + count <- stats::rpois(1, fraction) if (count > 0) { res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { - if (runif(1) < fraction) { + if (stats::runif(1) < fraction) { len <- len + 1 res[[len]] <- elem } @@ -942,7 +942,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", fraction <- 0.0 total <- 0 multiplier <- 3.0 - initialCount <- count(x) + initialCount <- countRDD(x) maxSelected <- 0 MAXINT <- .Machine$integer.max @@ -964,16 +964,16 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", } set.seed(seed) - samples <- collect(sampleRDD(x, withReplacement, fraction, - as.integer(ceiling(runif(1, + samples <- collectRDD(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(stats::runif(1, -MAXINT, MAXINT))))) # If the first sample didn't turn out large enough, keep trying to # take samples; this shouldn't happen often because we use a big # multiplier for thei initial size while (length(samples) < total) - samples <- collect(sampleRDD(x, withReplacement, fraction, - as.integer(ceiling(runif(1, + samples <- collectRDD(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(stats::runif(1, -MAXINT, MAXINT))))) @@ -990,7 +990,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3)) -#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#' collectRDD(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) #'} # nolint end #' @rdname keyBy @@ -1019,12 +1019,12 @@ setMethod("keyBy", #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) #' getNumPartitions(rdd) # 4 -#' getNumPartitions(repartition(rdd, 2L)) # 2 +#' getNumPartitions(repartitionRDD(rdd, 2L)) # 2 #'} #' @rdname repartition #' @aliases repartition,RDD #' @noRd -setMethod("repartition", +setMethod("repartitionRDD", signature(x = "RDD"), function(x, numPartitions) { if (!is.null(numPartitions) && is.numeric(numPartitions)) { @@ -1064,7 +1064,7 @@ setMethod("coalesce", }) } shuffled <- lapplyPartitionsWithIndex(x, func) - repartitioned <- partitionBy(shuffled, numPartitions) + repartitioned <- partitionByRDD(shuffled, numPartitions) values(repartitioned) } else { jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) @@ -1135,7 +1135,7 @@ setMethod("saveAsTextFile", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(3, 2, 1)) -#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#' collectRDD(sortBy(rdd, function(x) { x })) # list (1, 2, 3) #'} # nolint end #' @rdname sortBy @@ -1304,7 +1304,7 @@ setMethod("aggregateRDD", Reduce(seqOp, part, zeroValue) } - partitionList <- collect(lapplyPartition(x, partitionFunc), + partitionList <- collectRDD(lapplyPartition(x, partitionFunc), flatten = FALSE) Reduce(combOp, partitionList, zeroValue) }) @@ -1322,7 +1322,7 @@ setMethod("aggregateRDD", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' collect(pipeRDD(rdd, "more") +#' pipeRDD(rdd, "more") #' Output: c("1", "2", ..., "10") #'} #' @aliases pipeRDD,RDD,character-method @@ -1397,7 +1397,7 @@ setMethod("setName", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -#' collect(zipWithUniqueId(rdd)) +#' collectRDD(zipWithUniqueId(rdd)) #' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #'} # nolint end @@ -1440,7 +1440,7 @@ setMethod("zipWithUniqueId", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -#' collect(zipWithIndex(rdd)) +#' collectRDD(zipWithIndex(rdd)) #' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) #'} # nolint end @@ -1452,7 +1452,7 @@ setMethod("zipWithIndex", function(x) { n <- getNumPartitions(x) if (n > 1) { - nums <- collect(lapplyPartition(x, + nums <- collectRDD(lapplyPartition(x, function(part) { list(length(part)) })) @@ -1488,7 +1488,7 @@ setMethod("zipWithIndex", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, as.list(1:4), 2L) -#' collect(glom(rdd)) +#' collectRDD(glom(rdd)) #' # list(list(1, 2), list(3, 4)) #'} # nolint end @@ -1556,7 +1556,7 @@ setMethod("unionRDD", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 0:4) #' rdd2 <- parallelize(sc, 1000:1004) -#' collect(zipRDD(rdd1, rdd2)) +#' collectRDD(zipRDD(rdd1, rdd2)) #' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) #'} # nolint end @@ -1628,7 +1628,7 @@ setMethod("cartesian", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) #' rdd2 <- parallelize(sc, list(2, 4)) -#' collect(subtract(rdd1, rdd2)) +#' collectRDD(subtract(rdd1, rdd2)) #' # list(1, 1, 3) #'} # nolint end @@ -1662,7 +1662,7 @@ setMethod("subtract", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) #' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) -#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' collectRDD(sortBy(intersection(rdd1, rdd2), function(x) { x })) #' # list(1, 2, 3) #'} # nolint end @@ -1699,7 +1699,7 @@ setMethod("intersection", #' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 #' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 #' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' collectRDD(zipPartitions(rdd1, rdd2, rdd3, #' func = function(x, y, z) { list(list(x, y, z))} )) #' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #'} diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index a14bcd91b3ea..ce531c3f8886 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -115,7 +115,7 @@ infer_type <- function(x) { #' Get Runtime Config from the current active SparkSession #' #' Get Runtime Config from the current active SparkSession. -#' To change SparkSession Runtime Config, please see `sparkR.session()`. +#' To change SparkSession Runtime Config, please see \code{sparkR.session()}. #' #' @param key (optional) The key of the config to get, if omitted, all config is returned #' @param defaultValue (optional) The default value of the config to return if they config is not @@ -156,6 +156,25 @@ sparkR.conf <- function(key, defaultValue) { } } +#' Get version of Spark on which this application is running +#' +#' Get version of Spark on which this application is running. +#' +#' @return a character string of the Spark version +#' @rdname sparkR.version +#' @name sparkR.version +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' version <- sparkR.version() +#' } +#' @note sparkR.version since 2.0.1 +sparkR.version <- function() { + sparkSession <- getSparkSession() + callJMethod(sparkSession, "version") +} + getDefaultSqlSource <- function() { l <- sparkR.conf("spark.sql.sources.default", "org.apache.spark.sql.parquet") l[["spark.sql.sources.default"]] @@ -165,9 +184,9 @@ getDefaultSqlSource <- function() { #' #' Converts R data.frame or list into SparkDataFrame. #' -#' @param data An RDD or list or data.frame -#' @param schema a list of column names or named list (StructType), optional -#' @return a SparkDataFrame +#' @param data an RDD or list or data.frame. +#' @param schema a list of column names or named list (StructType), optional. +#' @return A SparkDataFrame. #' @rdname createDataFrame #' @export #' @examples @@ -183,7 +202,10 @@ getDefaultSqlSource <- function() { # TODO(davies): support sampling and infer type from NA createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { sparkSession <- getSparkSession() + if (is.data.frame(data)) { + # Convert data into a list of rows. Each row is a list. + # get the names of columns, they will be put into RDD if (is.null(schema)) { schema <- names(data) @@ -208,6 +230,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE) data <- do.call(mapply, append(args, data)) } + if (is.list(data)) { sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) rdd <- parallelize(sc, data) @@ -218,7 +241,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { } if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) { - row <- first(rdd) + row <- firstRDD(rdd) names <- if (is.null(schema)) { names(row) } else { @@ -257,23 +280,25 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { } createDataFrame <- function(x, ...) { - dispatchFunc("createDataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) + dispatchFunc("createDataFrame(data, schema = NULL)", x, ...) } +#' @param samplingRatio Currently not used. #' @rdname createDataFrame #' @aliases createDataFrame #' @export #' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { - createDataFrame(data, schema, samplingRatio) + createDataFrame(data, schema) } +#' @param ... additional argument(s). #' @rdname createDataFrame #' @aliases as.DataFrame #' @export -as.DataFrame <- function(x, ...) { - dispatchFunc("as.DataFrame(data, schema = NULL, samplingRatio = 1.0)", x, ...) +as.DataFrame <- function(data, ...) { + dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) } #' toDF @@ -398,7 +423,7 @@ read.orc <- function(path) { #' #' Loads a Parquet file, returning the result as a SparkDataFrame. #' -#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param path path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet #' @export @@ -418,6 +443,7 @@ read.parquet <- function(x, ...) { dispatchFunc("read.parquet(...)", x, ...) } +#' @param ... argument(s) passed to the method. #' @rdname read.parquet #' @name parquetFile #' @export @@ -717,16 +743,17 @@ dropTempView <- function(viewName) { #' #' Returns the dataset in a data source as a SparkDataFrame #' -#' The data source is specified by the `source` and a set of options(...). -#' If `source` is not specified, the default data source configured by +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. \cr -#' Similar to R read.csv, when `source` is "csv", by default, a value of "NA" will be interpreted -#' as NA. +#' Similar to R read.csv, when \code{source} is "csv", by default, a value of "NA" will be +#' interpreted as NA. #' #' @param path The path of files to load #' @param source The name of external data source #' @param schema The data schema defined in structType #' @param na.strings Default string value for NA when source is "csv" +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.df #' @name read.df @@ -787,14 +814,15 @@ loadDF <- function(x, ...) { #' Creates an external table based on the dataset in a data source, #' Returns a SparkDataFrame associated with the external table. #' -#' The data source is specified by the `source` and a set of options(...). -#' If `source` is not specified, the default data source configured by +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param tableName A name of the table -#' @param path The path of files to load -#' @param source the name of external data source -#' @return SparkDataFrame +#' @param tableName a name of the table. +#' @param path the path of files to load. +#' @param source the name of external data source. +#' @param ... additional argument(s) passed to the method. +#' @return A SparkDataFrame. #' @rdname createExternalTable #' @export #' @examples @@ -825,21 +853,22 @@ createExternalTable <- function(x, ...) { #' Additional JDBC database connection properties can be set (...) #' #' Only one of partitionColumn or predicates should be set. Partitions of the table will be -#' retrieved in parallel based on the `numPartitions` or by the predicates. +#' retrieved in parallel based on the \code{numPartitions} or by the predicates. #' #' Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash #' your external database systems. #' -#' @param url JDBC database url of the form `jdbc:subprotocol:subname` +#' @param url JDBC database url of the form \code{jdbc:subprotocol:subname} #' @param tableName the name of the table in the external database #' @param partitionColumn the name of a column of integral type that will be used for partitioning -#' @param lowerBound the minimum value of `partitionColumn` used to decide partition stride -#' @param upperBound the maximum value of `partitionColumn` used to decide partition stride -#' @param numPartitions the number of partitions, This, along with `lowerBound` (inclusive), -#' `upperBound` (exclusive), form partition strides for generated WHERE -#' clause expressions used to split the column `partitionColumn` evenly. +#' @param lowerBound the minimum value of \code{partitionColumn} used to decide partition stride +#' @param upperBound the maximum value of \code{partitionColumn} used to decide partition stride +#' @param numPartitions the number of partitions, This, along with \code{lowerBound} (inclusive), +#' \code{upperBound} (exclusive), form partition strides for generated WHERE +#' clause expressions used to split the column \code{partitionColumn} evenly. #' This defaults to SparkContext.defaultParallelism when unset. #' @param predicates a list of conditions in the where clause; each one defines one partition +#' @param ... additional JDBC database connection named properties. #' @return SparkDataFrame #' @rdname read.jdbc #' @name read.jdbc diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index 474638009624..4ac83c29c6f7 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -44,6 +44,7 @@ windowSpec <- function(sws) { } #' @rdname show +#' @export #' @note show(WindowSpec) since 2.0.0 setMethod("show", "WindowSpec", function(object) { @@ -54,8 +55,10 @@ setMethod("show", "WindowSpec", #' #' Defines the partitioning columns in a WindowSpec. #' -#' @param x a WindowSpec -#' @return a WindowSpec +#' @param x a WindowSpec. +#' @param col a column to partition on (desribed by the name or Column). +#' @param ... additional column(s) to partition on. +#' @return A WindowSpec. #' @rdname partitionBy #' @name partitionBy #' @aliases partitionBy,WindowSpec-method @@ -82,16 +85,18 @@ setMethod("partitionBy", } }) -#' orderBy +#' Ordering Columns in a WindowSpec #' #' Defines the ordering columns in a WindowSpec. -#' #' @param x a WindowSpec -#' @return a WindowSpec -#' @rdname arrange +#' @param col a character or Column indicating an ordering column +#' @param ... additional sorting fields +#' @return A WindowSpec. #' @name orderBy +#' @rdname orderBy #' @aliases orderBy,WindowSpec,character-method #' @family windowspec_method +#' @seealso See \link{arrange} for use in sorting a SparkDataFrame #' @export #' @examples #' \dontrun{ @@ -105,7 +110,7 @@ setMethod("orderBy", windowSpec(callJMethod(x@sws, "orderBy", col, list(...))) }) -#' @rdname arrange +#' @rdname orderBy #' @name orderBy #' @aliases orderBy,WindowSpec,Column-method #' @export @@ -121,11 +126,11 @@ setMethod("orderBy", #' rowsBetween #' -#' Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). -#' -#' Both `start` and `end` are relative positions from the current row. For example, "0" means -#' "current row", while "-1" means the row before the current row, and "5" means the fifth row -#' after the current row. +#' Defines the frame boundaries, from \code{start} (inclusive) to \code{end} (inclusive). +#' +#' Both \code{start} and \code{end} are relative positions from the current row. For example, +#' "0" means "current row", while "-1" means the row before the current row, and "5" means the +#' fifth row after the current row. #' #' @param x a WindowSpec #' @param start boundary start, inclusive. @@ -153,12 +158,12 @@ setMethod("rowsBetween", #' rangeBetween #' -#' Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). -#' -#' Both `start` and `end` are relative from the current row. For example, "0" means "current row", -#' while "-1" means one off before the current row, and "5" means the five off after the -#' current row. - +#' Defines the frame boundaries, from \code{start} (inclusive) to \code{end} (inclusive). +#' +#' Both \code{start} and \code{end} are relative from the current row. For example, "0" means +#' "current row", while "-1" means one off before the current row, and "5" means the five off +#' after the current row. +#' #' @param x a WindowSpec #' @param start boundary start, inclusive. #' The frame is unbounded if this is the minimum long value. @@ -188,13 +193,28 @@ setMethod("rangeBetween", #' over #' -#' Define a windowing column. +#' Define a windowing column. #' +#' @param x a Column, usually one returned by window function(s). +#' @param window a WindowSpec object. Can be created by \code{windowPartitionBy} or +#' \code{windowOrderBy} and configured by other WindowSpec methods. #' @rdname over #' @name over #' @aliases over,Column,WindowSpec-method #' @family colum_func #' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Rank on hp within each partition +#' out <- select(df, over(rank(), ws), df$hp, df$am) +#' +#' # Lag mpg values by 1 row on the partition-and-ordered table +#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) +#' } #' @note over since 2.0.0 setMethod("over", signature(x = "Column", window = "WindowSpec"), diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 0edb9d2ae5c4..539d91b0f879 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -163,8 +163,9 @@ setMethod("alias", #' @family colum_func #' @aliases substr,Column-method #' -#' @param start starting position -#' @param stop ending position +#' @param x a Column. +#' @param start starting position. +#' @param stop ending position. #' @note substr since 1.4.0 setMethod("substr", signature(x = "Column"), function(x, start, stop) { @@ -219,6 +220,7 @@ setMethod("endsWith", signature(x = "Column"), #' @family colum_func #' @aliases between,Column-method #' +#' @param x a Column #' @param bounds lower and upper bounds #' @note between since 1.5.0 setMethod("between", signature(x = "Column"), @@ -233,6 +235,11 @@ setMethod("between", signature(x = "Column"), #' Casts the column to a different data type. #' +#' @param x a Column. +#' @param dataType a character object describing the target data type. +#' See +#' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{ +#' Spark Data Types} for available data types. #' @rdname cast #' @name cast #' @family colum_func @@ -254,10 +261,12 @@ setMethod("cast", #' Match a column with given values. #' +#' @param x a Column. +#' @param table a collection of values (coercible to list) to compare with. #' @rdname match #' @name %in% #' @aliases %in%,Column-method -#' @return a matched values as a result of comparing with given values. +#' @return A matched values as a result of comparing with given values. #' @export #' @examples #' \dontrun{ @@ -275,8 +284,11 @@ setMethod("%in%", #' otherwise #' #' If values in the specified column are null, returns the value. -#' Can be used in conjunction with `when` to specify a default value for expressions. +#' Can be used in conjunction with \code{when} to specify a default value for expressions. #' +#' @param x a Column. +#' @param value value to replace when the corresponding entry in \code{x} is NA. +#' Can be a single value or a Column. #' @rdname otherwise #' @name otherwise #' @family colum_func diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 2538bb25073e..13ade49eabfa 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -267,7 +267,7 @@ spark.lapply <- function(list, func) { sc <- getSparkContext() rdd <- parallelize(sc, list, length(list)) results <- map(rdd, func) - local <- collect(results) + local <- collectRDD(results) local } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 573c915a5c67..4d94b4cd05d4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -23,6 +23,7 @@ NULL #' A new \linkS4class{Column} is created to represent the literal value. #' If the parameter is a \linkS4class{Column}, it is returned unchanged. #' +#' @param x a literal value or a Column. #' @family normal_funcs #' @rdname lit #' @name lit @@ -89,8 +90,6 @@ setMethod("acos", #' Returns the approximate number of distinct items in a group. This is a column #' aggregate function. #' -#' @param x Column to compute on. -#' #' @rdname approxCountDistinct #' @name approxCountDistinct #' @return the approximate number of distinct items in a group. @@ -171,8 +170,6 @@ setMethod("atan", #' #' Aggregate function: returns the average of the values in a group. #' -#' @param x Column to compute on. -#' #' @rdname avg #' @name avg #' @family agg_funcs @@ -319,7 +316,7 @@ setMethod("column", #' #' Computes the Pearson Correlation Coefficient for two Columns. #' -#' @param x Column to compute on. +#' @param col2 a (second) Column. #' #' @rdname corr #' @name corr @@ -339,8 +336,6 @@ setMethod("corr", signature(x = "Column"), #' #' Compute the sample covariance between two expressions. #' -#' @param x Column to compute on. -#' #' @rdname cov #' @name cov #' @family math_funcs @@ -362,8 +357,8 @@ setMethod("cov", signature(x = "characterOrColumn"), #' @rdname cov #' -#' @param col1 First column to compute cov_samp. -#' @param col2 Second column to compute cov_samp. +#' @param col1 the first Column. +#' @param col2 the second Column. #' @name covar_samp #' @aliases covar_samp,characterOrColumn,characterOrColumn-method #' @note covar_samp since 2.0.0 @@ -449,11 +444,10 @@ setMethod("cosh", #' Returns the number of items in a group #' -#' Returns the number of items in a group. This is a column aggregate function. -#' -#' @param x Column to compute on. +#' This can be used as a column aggregate function with \code{Column} as input, +#' and returns the number of items in a group. #' -#' @rdname nrow +#' @rdname count #' @name count #' @family agg_funcs #' @aliases count,Column-method @@ -493,6 +487,7 @@ setMethod("crc32", #' Calculates the hash code of given columns, and returns the result as a int column. #' #' @param x Column to compute on. +#' @param ... additional Column(s) to be included. #' #' @rdname hash #' @name hash @@ -663,7 +658,8 @@ setMethod("factorial", #' The function by default returns the first values it sees. It will return the first non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. #' -#' @param x Column to compute on. +#' @param na.rm a logical value indicating whether NA values should be stripped +#' before the computation proceeds. #' #' @rdname first #' @name first @@ -832,7 +828,10 @@ setMethod("kurtosis", #' The function by default returns the last values it sees. It will return the last non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. #' -#' @param x Column to compute on. +#' @param x column to compute on. +#' @param na.rm a logical value indicating whether NA values should be stripped +#' before the computation proceeds. +#' @param ... further arguments to be passed to or from other methods. #' #' @rdname last #' @name last @@ -1143,7 +1142,7 @@ setMethod("minute", #' @export #' @examples \dontrun{select(df, monotonically_increasing_id())} setMethod("monotonically_increasing_id", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "monotonically_increasing_id") column(jc) @@ -1252,7 +1251,7 @@ setMethod("rint", #' round #' -#' Returns the value of the column `e` rounded to 0 decimal places using HALF_UP rounding mode. +#' Returns the value of the column \code{e} rounded to 0 decimal places using HALF_UP rounding mode. #' #' @param x Column to compute on. #' @@ -1272,13 +1271,16 @@ setMethod("round", #' bround #' -#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding -#' mode if `scale` >= 0 or at integral part when `scale` < 0. +#' Returns the value of the column \code{e} rounded to \code{scale} decimal places using HALF_EVEN rounding +#' mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' #' @param x Column to compute on. -#' +#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, +#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left +#' of the decimal point when \code{scale} < 0. +#' @param ... further arguments to be passed to or from other methods. #' @rdname bround #' @name bround #' @family math_funcs @@ -1319,7 +1321,7 @@ setMethod("rtrim", #' Aggregate function: alias for \link{stddev_samp} #' #' @param x Column to compute on. -#' +#' @param na.rm currently not used. #' @rdname sd #' @name sd #' @family agg_funcs @@ -1497,7 +1499,7 @@ setMethod("soundex", #' \dontrun{select(df, spark_partition_id())} #' @note spark_partition_id since 2.0.0 setMethod("spark_partition_id", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "spark_partition_id") column(jc) @@ -1560,7 +1562,8 @@ setMethod("stddev_samp", #' #' Creates a new struct column that composes multiple input columns. #' -#' @param x Column to compute on. +#' @param x a column to compute on. +#' @param ... optional column(s) to be included. #' #' @rdname struct #' @name struct @@ -1831,8 +1834,8 @@ setMethod("upper", #' #' Aggregate function: alias for \link{var_samp}. #' -#' @param x Column to compute on. -#' +#' @param x a Column to compute on. +#' @param y,na.rm,use currently not used. #' @rdname var #' @name var #' @family agg_funcs @@ -1972,7 +1975,7 @@ setMethod("atan2", signature(y = "Column"), #' datediff #' -#' Returns the number of days from `start` to `end`. +#' Returns the number of days from \code{start} to \code{end}. #' #' @param x start Column to use. #' @param y end Column to use. @@ -2041,7 +2044,7 @@ setMethod("levenshtein", signature(y = "Column"), #' months_between #' -#' Returns number of months between dates `date1` and `date2`. +#' Returns number of months between dates \code{date1} and \code{date2}. #' #' @param x start Column to use. #' @param y end Column to use. @@ -2114,7 +2117,9 @@ setMethod("pmod", signature(y = "Column"), #' @rdname approxCountDistinct #' @name approxCountDistinct #' +#' @param x Column to compute on. #' @param rsd maximum estimation error allowed (default = 0.05) +#' @param ... further arguments to be passed to or from other methods. #' #' @aliases approxCountDistinct,Column-method #' @export @@ -2127,7 +2132,7 @@ setMethod("approxCountDistinct", column(jc) }) -#' Count Distinct +#' Count Distinct Values #' #' @param x Column to compute on #' @param ... other columns @@ -2156,7 +2161,7 @@ setMethod("countDistinct", #' concat #' #' Concatenates multiple input string columns together into a single string column. -#' +#' #' @param x Column to compute on #' @param ... other columns #' @@ -2246,7 +2251,6 @@ setMethod("ceiling", }) #' @rdname sign -#' @param x Column to compute on #' #' @name sign #' @aliases sign,Column-method @@ -2262,9 +2266,6 @@ setMethod("sign", signature(x = "Column"), #' #' Aggregate function: returns the number of distinct items in a group. #' -#' @param x Column to compute on -#' @param ... other columns -#' #' @rdname countDistinct #' @name n_distinct #' @aliases n_distinct,Column-method @@ -2276,9 +2277,7 @@ setMethod("n_distinct", signature(x = "Column"), countDistinct(x, ...) }) -#' @rdname nrow -#' @param x Column to compute on -#' +#' @rdname count #' @name n #' @aliases n,Column-method #' @export @@ -2300,8 +2299,8 @@ setMethod("n", signature(x = "Column"), #' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' -#' @param y Column to compute on -#' @param x date format specification +#' @param y Column to compute on. +#' @param x date format specification. #' #' @family datetime_funcs #' @rdname date_format @@ -2320,8 +2319,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' Assumes given timestamp is UTC and converts to given timezone. #' -#' @param y Column to compute on -#' @param x time zone to use +#' @param y Column to compute on. +#' @param x time zone to use. #' #' @family datetime_funcs #' @rdname from_utc_timestamp @@ -2370,8 +2369,8 @@ setMethod("instr", signature(y = "Column", x = "character"), #' Day of the week parameter is case insensitive, and accepts first three or two characters: #' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' -#' @param y Column to compute on -#' @param x Day of the week string +#' @param y Column to compute on. +#' @param x Day of the week string. #' #' @family datetime_funcs #' @rdname next_day @@ -2432,7 +2431,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' date_add #' -#' Returns the date that is `days` days after `start` +#' Returns the date that is \code{x} days after #' #' @param y Column to compute on #' @param x Number of days to add @@ -2452,7 +2451,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' date_sub #' -#' Returns the date that is `days` days before `start` +#' Returns the date that is \code{x} days before #' #' @param y Column to compute on #' @param x Number of days to substract @@ -2637,6 +2636,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' Parses the expression string into the column that it represents, similar to #' SparkDataFrame.selectExpr #' +#' @param x an expression character object to be parsed. #' @family normal_funcs #' @rdname expr #' @aliases expr,character-method @@ -2654,6 +2654,9 @@ setMethod("expr", signature(x = "character"), #' #' Formats the arguments in printf-style and returns the result as a string column. #' +#' @param format a character object of format strings. +#' @param x a Column. +#' @param ... additional Column(s). #' @family string_funcs #' @rdname format_string #' @name format_string @@ -2676,6 +2679,11 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' representing the timestamp of that moment in the current system time zone in the given #' format. #' +#' @param x a Column of unix timestamp. +#' @param format the target format. See +#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' Customizing Formats} for available options. +#' @param ... further arguments to be passed to or from other methods. #' @family datetime_funcs #' @rdname from_unixtime #' @name from_unixtime @@ -2702,19 +2710,25 @@ setMethod("from_unixtime", signature(x = "Column"), #' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in #' the order of months are not supported. #' -#' The time column must be of TimestampType. -#' -#' Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid -#' interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. -#' If the `slideDuration` is not provided, the windows will be tumbling windows. -#' -#' The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start -#' window intervals. For example, in order to have hourly tumbling windows that start 15 minutes -#' past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. -#' -#' The output column will be a struct called 'window' by default with the nested columns 'start' -#' and 'end'. -#' +#' @param x a time Column. Must be of TimestampType. +#' @param windowDuration a string specifying the width of the window, e.g. '1 second', +#' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', +#' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that +#' the duration is a fixed length of time, and does not vary over time +#' according to a calendar. For example, '1 day' always means 86,400,000 +#' milliseconds, not a calendar day. +#' @param slideDuration a string specifying the sliding interval of the window. Same format as +#' \code{windowDuration}. A new window will be generated every +#' \code{slideDuration}. Must be less than or equal to +#' the \code{windowDuration}. This duration is likewise absolute, and does not +#' vary according to a calendar. +#' @param startTime the offset with respect to 1970-01-01 00:00:00 UTC with which to start +#' window intervals. For example, in order to have hourly tumbling windows +#' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide +#' \code{startTime} as \code{"15 minutes"}. +#' @param ... further arguments to be passed to or from other methods. +#' @return An output column of struct called 'window' by default with the nested columns 'start' +#' and 'end'. #' @family datetime_funcs #' @rdname window #' @name window @@ -2766,6 +2780,10 @@ setMethod("window", signature(x = "Column"), #' NOTE: The position is not zero based, but 1 based index, returns 0 if substr #' could not be found in str. #' +#' @param substr a character string to be matched. +#' @param str a Column where matches are sought for each entry. +#' @param pos start position of search. +#' @param ... further arguments to be passed to or from other methods. #' @family string_funcs #' @rdname locate #' @aliases locate,character,Column-method @@ -2785,6 +2803,9 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' #' Left-pad the string column with #' +#' @param x the string Column to be left-padded. +#' @param len maximum length of each output result. +#' @param pad a character string to be padded with. #' @family string_funcs #' @rdname lpad #' @aliases lpad,Column,numeric,character-method @@ -2804,6 +2825,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' #' Generate a random column with i.i.d. samples from U[0.0, 1.0]. #' +#' @param seed a random seed. Can be missing. #' @family normal_funcs #' @rdname rand #' @name rand @@ -2832,6 +2854,7 @@ setMethod("rand", signature(seed = "numeric"), #' #' Generate a column with i.i.d. samples from the standard normal distribution. #' +#' @param seed a random seed. Can be missing. #' @family normal_funcs #' @rdname randn #' @name randn @@ -2858,8 +2881,12 @@ setMethod("randn", signature(seed = "numeric"), #' regexp_extract #' -#' Extract a specific(idx) group identified by a java regex, from the specified string column. +#' Extract a specific \code{idx} group identified by a Java regex, from the specified string column. +#' If the regex did not match, or the specified group did not match, an empty string is returned. #' +#' @param x a string Column. +#' @param pattern a regular expression. +#' @param idx a group index. #' @family string_funcs #' @rdname regexp_extract #' @name regexp_extract @@ -2880,6 +2907,9 @@ setMethod("regexp_extract", #' #' Replace all substrings of the specified string value that match regexp with rep. #' +#' @param x a string Column. +#' @param pattern a regular expression. +#' @param replacement a character string that a matched \code{pattern} is replaced with. #' @family string_funcs #' @rdname regexp_replace #' @name regexp_replace @@ -2900,6 +2930,9 @@ setMethod("regexp_replace", #' #' Right-padded with pad to a length of len. #' +#' @param x the string Column to be right-padded. +#' @param len maximum length of each output result. +#' @param pad a character string to be padded with. #' @family string_funcs #' @rdname rpad #' @name rpad @@ -2922,6 +2955,11 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' returned. If count is negative, every to the right of the final delimiter (counting from the #' right) is returned. substring_index performs a case-sensitive match when searching for delim. #' +#' @param x a Column. +#' @param delim a delimiter string. +#' @param count number of occurrences of \code{delim} before the substring is returned. +#' A positive number means counting from the left, while negative means +#' counting from the right. #' @family string_funcs #' @rdname substring_index #' @aliases substring_index,Column,character,numeric-method @@ -2949,6 +2987,11 @@ setMethod("substring_index", #' The translate will happen when any character in the string matching with the character #' in the matchingString. #' +#' @param x a string Column. +#' @param matchingString a source string where each character will be translated. +#' @param replaceString a target string where each \code{matchingString} character will +#' be replaced by the character in \code{replaceString} +#' at the same location, if any. #' @family string_funcs #' @rdname translate #' @name translate @@ -2997,6 +3040,10 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), column(jc) }) +#' @param x a Column of date, in string, date or timestamp type. +#' @param format the target format. See +#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' Customizing Formats} for available options. #' @rdname unix_timestamp #' @name unix_timestamp #' @aliases unix_timestamp,Column,character-method @@ -3012,6 +3059,8 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' Evaluates a list of conditions and returns one of multiple possible result expressions. #' For unmatched expressions null is returned. #' +#' @param condition the condition to test on. Must be a Column expression. +#' @param value result expression. #' @family normal_funcs #' @rdname when #' @name when @@ -3033,6 +3082,9 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. #' Otherwise \code{no} is returned for unmatched conditions. #' +#' @param test a Column expression that describes the condition. +#' @param yes return values for \code{TRUE} elements of test. +#' @param no return values for \code{FALSE} elements of test. #' @family normal_funcs #' @rdname ifelse #' @name ifelse @@ -3067,17 +3119,21 @@ setMethod("ifelse", #' N = total number of rows in the partition #' cume_dist(x) = number of values before (and including) x / N #' -#' This is equivalent to the CUME_DIST function in SQL. +#' This is equivalent to the \code{CUME_DIST} function in SQL. #' #' @rdname cume_dist #' @name cume_dist #' @family window_funcs #' @aliases cume_dist,missing-method #' @export -#' @examples \dontrun{cume_dist()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(cume_dist(), ws), df$hp, df$am) +#' } #' @note cume_dist since 1.6.0 setMethod("cume_dist", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist") column(jc) @@ -3091,17 +3147,21 @@ setMethod("cume_dist", #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. #' -#' This is equivalent to the DENSE_RANK function in SQL. +#' This is equivalent to the \code{DENSE_RANK} function in SQL. #' #' @rdname dense_rank #' @name dense_rank #' @family window_funcs #' @aliases dense_rank,missing-method #' @export -#' @examples \dontrun{dense_rank()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(dense_rank(), ws), df$hp, df$am) +#' } #' @note dense_rank since 1.6.0 setMethod("dense_rank", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank") column(jc) @@ -3109,22 +3169,35 @@ setMethod("dense_rank", #' lag #' -#' Window function: returns the value that is `offset` rows before the current row, and -#' `defaultValue` if there is less than `offset` rows before the current row. For example, -#' an `offset` of one will return the previous row at any given point in the window partition. +#' Window function: returns the value that is \code{offset} rows before the current row, and +#' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example, +#' an \code{offset} of one will return the previous row at any given point in the window partition. #' -#' This is equivalent to the LAG function in SQL. +#' This is equivalent to the \code{LAG} function in SQL. #' +#' @param x the column as a character string or a Column to compute on. +#' @param offset the number of rows back from the current row from which to obtain a value. +#' If not specified, the default is 1. +#' @param defaultValue (optional) default to use when the offset row does not exist. +#' @param ... further arguments to be passed to or from other methods. #' @rdname lag #' @name lag #' @aliases lag,characterOrColumn-method #' @family window_funcs #' @export -#' @examples \dontrun{lag(df$c)} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Lag mpg values by 1 row on the partition-and-ordered table +#' out <- select(df, over(lag(df$mpg), ws), df$mpg, df$hp, df$am) +#' } #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), - function(x, offset, defaultValue = NULL) { + function(x, offset = 1, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc } else { @@ -3138,26 +3211,36 @@ setMethod("lag", #' lead #' -#' Window function: returns the value that is `offset` rows after the current row, and -#' `null` if there is less than `offset` rows after the current row. For example, -#' an `offset` of one will return the next row at any given point in the window partition. +#' Window function: returns the value that is \code{offset} rows after the current row, and +#' \code{defaultValue} if there is less than \code{offset} rows after the current row. +#' For example, an \code{offset} of one will return the next row at any given point +#' in the window partition. #' -#' This is equivalent to the LEAD function in SQL. -#' -#' @param x Column to compute on -#' @param offset Number of rows to offset -#' @param defaultValue (Optional) default value to use +#' This is equivalent to the \code{LEAD} function in SQL. +#' +#' @param x the column as a character string or a Column to compute on. +#' @param offset the number of rows after the current row from which to obtain a value. +#' If not specified, the default is 1. +#' @param defaultValue (optional) default to use when the offset row does not exist. #' #' @rdname lead #' @name lead #' @family window_funcs #' @aliases lead,characterOrColumn,numeric-method #' @export -#' @examples \dontrun{lead(df$c)} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Lead mpg values by 1 row on the partition-and-ordered table +#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) +#' } #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), - function(x, offset, defaultValue = NULL) { + function(x, offset = 1, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc } else { @@ -3171,11 +3254,11 @@ setMethod("lead", #' ntile #' -#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window -#' partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second +#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window +#' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. #' -#' This is equivalent to the NTILE function in SQL. +#' This is equivalent to the \code{NTILE} function in SQL. #' #' @param x Number of ntile groups #' @@ -3184,7 +3267,15 @@ setMethod("lead", #' @aliases ntile,numeric-method #' @family window_funcs #' @export -#' @examples \dontrun{ntile(1)} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Get ntile group id (1-4) for hp +#' out <- select(df, over(ntile(4), ws), df$hp, df$am) +#' } #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3208,10 +3299,14 @@ setMethod("ntile", #' @family window_funcs #' @aliases percent_rank,missing-method #' @export -#' @examples \dontrun{percent_rank()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(percent_rank(), ws), df$hp, df$am) +#' } #' @note percent_rank since 1.6.0 setMethod("percent_rank", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank") column(jc) @@ -3233,7 +3328,11 @@ setMethod("percent_rank", #' @family window_funcs #' @aliases rank,missing-method #' @export -#' @examples \dontrun{rank()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(rank(), ws), df$hp, df$am) +#' } #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3243,6 +3342,8 @@ setMethod("rank", }) # Expose rank() in the R base package +#' @param x a numeric, complex, character or logical vector. +#' @param ... additional argument(s) passed to the method. #' @name rank #' @rdname rank #' @aliases rank,ANY-method @@ -3264,10 +3365,14 @@ setMethod("rank", #' @aliases row_number,missing-method #' @family window_funcs #' @export -#' @examples \dontrun{row_number()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(row_number(), ws), df$hp, df$am) +#' } #' @note row_number since 1.6.0 setMethod("row_number", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "row_number") column(jc) @@ -3318,7 +3423,7 @@ setMethod("explode", #' size #' #' Returns length of array or map. -#' +#' #' @param x Column to compute on #' #' @rdname size diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e7444ac2467d..b54a92a3c6dd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -23,9 +23,7 @@ setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) -# @rdname cache-methods -# @export -setGeneric("cache", function(x) { standardGeneric("cache") }) +setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition @@ -36,9 +34,7 @@ setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coales # @export setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) -# @rdname collect-methods -# @export -setGeneric("collect", function(x, ...) { standardGeneric("collect") }) +setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) # @rdname collect-methods # @export @@ -51,9 +47,9 @@ setGeneric("collectPartition", standardGeneric("collectPartition") }) -# @rdname nrow -# @export -setGeneric("count", function(x) { standardGeneric("count") }) +setGeneric("countRDD", function(x) { standardGeneric("countRDD") }) + +setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") }) # @rdname countByValue # @export @@ -74,17 +70,13 @@ setGeneric("approxQuantile", standardGeneric("approxQuantile") }) -# @rdname distinct -# @export -setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) +setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") }) # @rdname filterRDD # @export setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) -# @rdname first -# @export -setGeneric("first", function(x, ...) { standardGeneric("first") }) +setGeneric("firstRDD", function(x, ...) { standardGeneric("firstRDD") }) # @rdname flatMap # @export @@ -110,6 +102,8 @@ setGeneric("glom", function(x) { standardGeneric("glom") }) # @export setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") }) +setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") }) + # @rdname keyBy # @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) @@ -152,9 +146,7 @@ setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") # @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) -# @rdname persist -# @export -setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) +setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") }) # @rdname pipeRDD # @export @@ -168,10 +160,7 @@ setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("piv # @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) -# @rdname repartition -# @seealso coalesce -# @export -setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) +setGeneric("repartitionRDD", function(x, ...) { standardGeneric("repartitionRDD") }) # @rdname sampleRDD # @export @@ -193,6 +182,8 @@ setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile # @export setGeneric("setName", function(x, name) { standardGeneric("setName") }) +setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") }) + # @rdname sortBy # @export setGeneric("sortBy", @@ -200,9 +191,7 @@ setGeneric("sortBy", standardGeneric("sortBy") }) -# @rdname take -# @export -setGeneric("take", function(x, num) { standardGeneric("take") }) +setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") }) # @rdname takeOrdered # @export @@ -223,9 +212,7 @@ setGeneric("top", function(x, num) { standardGeneric("top") }) # @export setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) -# @rdname unpersist-methods -# @export -setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) +setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") }) # @rdname zipRDD # @export @@ -343,9 +330,7 @@ setGeneric("join", function(x, y, ...) { standardGeneric("join") }) # @export setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) -#' @rdname partitionBy -#' @export -setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) +setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") }) # @rdname reduceByKey # @seealso groupByKey @@ -395,6 +380,9 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #################### SparkDataFrame Methods ######################## +#' @param x a SparkDataFrame or GroupedData. +#' @param ... further arguments to be passed to or from other methods. +#' @return A SparkDataFrame. #' @rdname summarize #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) @@ -414,6 +402,16 @@ setGeneric("as.data.frame", #' @export setGeneric("attach") +#' @rdname cache +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname collect +#' @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +#' @param do.NULL currently not used. +#' @param prefix currently not used. #' @rdname columns #' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) @@ -434,11 +432,24 @@ setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) +#' @param x a GroupedData or Column. +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + #' @rdname cov +#' @param x a Column or a SparkDataFrame. +#' @param ... additional argument(s). If \code{x} is a Column, a Column +#' should be provided. If \code{x} is a SparkDataFrame, two column names should +#' be provided. #' @export setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname corr +#' @param x a Column or a SparkDataFrame. +#' @param ... additional argument(s). If \code{x} is a Column, a Column +#' should be provided. If \code{x} is a SparkDataFrame, two column names should +#' be provided. #' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) @@ -465,10 +476,14 @@ setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) #' @export setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) +#' @param x a SparkDataFrame or GroupedData. +#' @param ... additional argument(s) passed to the method. #' @rdname gapply #' @export setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) +#' @param x a SparkDataFrame or GroupedData. +#' @param ... additional argument(s) passed to the method. #' @rdname gapplyCollect #' @export setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) @@ -477,6 +492,10 @@ setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname distinct +#' @export +setGeneric("distinct", function(x) { standardGeneric("distinct") }) + #' @rdname drop #' @export setGeneric("drop", function(x, ...) { standardGeneric("drop") }) @@ -519,6 +538,10 @@ setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) +#' @rdname first +#' @export +setGeneric("first", function(x, ...) { standardGeneric("first") }) + #' @rdname groupBy #' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) @@ -551,21 +574,29 @@ setGeneric("merge") #' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) -#' @rdname arrange +#' @rdname orderBy #' @export setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") }) +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + #' @rdname printSchema #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) +#' @rdname registerTempTable-deprecated +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + #' @rdname rename #' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) -#' @rdname registerTempTable-deprecated +#' @rdname repartition #' @export -setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) +setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample #' @export @@ -592,6 +623,10 @@ setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", #' @export setGeneric("str") +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + #' @rdname mutate #' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) @@ -650,8 +685,8 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) -# @rdname subset -# @export +#' @rdname subset +#' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname summarize @@ -674,6 +709,10 @@ setGeneric("union", function(x, y) { standardGeneric("union") }) #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) + #' @rdname filter #' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) @@ -714,6 +753,8 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") }) setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) #' @rdname columnfunctions +#' @param x a Column object. +#' @param ... additional argument(s). #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) @@ -771,6 +812,10 @@ setGeneric("over", function(x, window) { standardGeneric("over") }) ###################### WindowSpec Methods ########################## +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) + #' @rdname rowsBetween #' @export setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") }) @@ -805,6 +850,8 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) +#' @param x Column to compute on or a GroupedData object. +#' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) @@ -861,9 +908,10 @@ setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @export setGeneric("hash", function(x, ...) { standardGeneric("hash") }) +#' @param x empty. Should be used with no argument. #' @rdname cume_dist #' @export -setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) +setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname datediff #' @export @@ -893,9 +941,10 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @export setGeneric("decode", function(x, charset) { standardGeneric("decode") }) +#' @param x empty. Should be used with no argument. #' @rdname dense_rank #' @export -setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) +setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname encode #' @export @@ -1009,10 +1058,11 @@ setGeneric("md5", function(x) { standardGeneric("md5") }) #' @export setGeneric("minute", function(x) { standardGeneric("minute") }) +#' @param x empty. Should be used with no argument. #' @rdname monotonically_increasing_id #' @export setGeneric("monotonically_increasing_id", - function(x) { standardGeneric("monotonically_increasing_id") }) + function(x = "missing") { standardGeneric("monotonically_increasing_id") }) #' @rdname month #' @export @@ -1022,7 +1072,7 @@ setGeneric("month", function(x) { standardGeneric("month") }) #' @export setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) -#' @rdname nrow +#' @rdname count #' @export setGeneric("n", function(x) { standardGeneric("n") }) @@ -1046,9 +1096,10 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) +#' @param x empty. Should be used with no argument. #' @rdname percent_rank #' @export -setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") }) +setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname pmod #' @export @@ -1089,11 +1140,12 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname rint #' @export -setGeneric("rint", function(x, ...) { standardGeneric("rint") }) +setGeneric("rint", function(x) { standardGeneric("rint") }) +#' @param x empty. Should be used with no argument. #' @rdname row_number #' @export -setGeneric("row_number", function(x) { standardGeneric("row_number") }) +setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname rpad #' @export @@ -1151,9 +1203,10 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) +#' @param x empty. Should be used with no argument. #' @rdname spark_partition_id #' @export -setGeneric("spark_partition_id", function(x) { standardGeneric("spark_partition_id") }) +setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname sd #' @export @@ -1251,10 +1304,16 @@ setGeneric("year", function(x) { standardGeneric("year") }) #' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) +#' @param x,y For \code{glm}: logical values indicating whether the response vector +#' and model matrix used in the fitting process should be returned as +#' components of the returned value. +#' @inheritParams stats::glm #' @rdname glm #' @export setGeneric("glm") +#' @param object a fitted ML model object. +#' @param ... additional argument(s) passed to the method. #' @rdname predict #' @export setGeneric("predict", function(object, ...) { standardGeneric("predict") }) @@ -1277,8 +1336,11 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s #' @rdname spark.survreg #' @export -setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) +setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +#' @param object a fitted ML model object. +#' @param path the directory where the model is saved. +#' @param ... additional argument(s) passed to the method. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 85348ae76baa..17f5283abead 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -48,6 +48,7 @@ groupedData <- function(sgd) { #' @rdname show #' @aliases show,GroupedData-method +#' @export #' @note show(GroupedData) since 1.4.0 setMethod("show", "GroupedData", function(object) { @@ -56,11 +57,10 @@ setMethod("show", "GroupedData", #' Count #' -#' Count the number of rows for each group. +#' Count the number of rows for each group when we have \code{GroupedData} input. #' The resulting SparkDataFrame will also contain the grouping columns. #' -#' @param x a GroupedData -#' @return a SparkDataFrame +#' @return A SparkDataFrame. #' @rdname count #' @aliases count,GroupedData-method #' @export @@ -83,8 +83,6 @@ setMethod("count", #' df2 <- agg(df, = ) #' df2 <- agg(df, newColName = aggFunction(column)) #' -#' @param x a GroupedData -#' @return a SparkDataFrame #' @rdname summarize #' @aliases agg,GroupedData-method #' @name agg @@ -201,7 +199,6 @@ createMethods() #' gapply #' -#' @param x A GroupedData #' @rdname gapply #' @aliases gapply,GroupedData-method #' @name gapply @@ -216,7 +213,6 @@ setMethod("gapply", #' gapplyCollect #' -#' @param x A GroupedData #' @rdname gapplyCollect #' @aliases gapplyCollect,GroupedData-method #' @name gapplyCollect diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R new file mode 100644 index 000000000000..69b0a523b84e --- /dev/null +++ b/R/pkg/R/install.R @@ -0,0 +1,257 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Functions to install Spark in case the user directly downloads SparkR +# from CRAN. + +#' Download and Install Apache Spark to a Local Directory +#' +#' \code{install.spark} downloads and installs Spark to a local directory if +#' it is not found. The Spark version we use is the same as the SparkR version. +#' Users can specify a desired Hadoop version, the remote mirror site, and +#' the directory where the package is installed locally. +#' +#' The full url of remote file is inferred from \code{mirrorUrl} and \code{hadoopVersion}. +#' \code{mirrorUrl} specifies the remote path to a Spark folder. It is followed by a subfolder +#' named after the Spark version (that corresponds to SparkR), and then the tar filename. +#' The filename is composed of four parts, i.e. [Spark version]-bin-[Hadoop version].tgz. +#' For example, the full path for a Spark 2.0.0 package for Hadoop 2.7 from +#' \code{http://apache.osuosl.org} has path: +#' \code{http://apache.osuosl.org/spark/spark-2.0.0/spark-2.0.0-bin-hadoop2.7.tgz}. +#' For \code{hadoopVersion = "without"}, [Hadoop version] in the filename is then +#' \code{without-hadoop}. +#' +#' @param hadoopVersion Version of Hadoop to install. Default is \code{"2.7"}. It can take other +#' version number in the format of "x.y" where x and y are integer. +#' If \code{hadoopVersion = "without"}, "Hadoop free" build is installed. +#' See +#' \href{http://spark.apache.org/docs/latest/hadoop-provided.html}{ +#' "Hadoop Free" Build} for more information. +#' Other patched version names can also be used, e.g. \code{"cdh4"} +#' @param mirrorUrl base URL of the repositories to use. The directory layout should follow +#' \href{http://www.apache.org/dyn/closer.lua/spark/}{Apache mirrors}. +#' @param localDir a local directory where Spark is installed. The directory contains +#' version-specific folders of Spark packages. Default is path to +#' the cache directory: +#' \itemize{ +#' \item Mac OS X: \file{~/Library/Caches/spark} +#' \item Unix: \env{$XDG_CACHE_HOME} if defined, otherwise \file{~/.cache/spark} +#' \item Windows: \file{\%LOCALAPPDATA\%\\spark\\spark\\Cache}. +#' } +#' @param overwrite If \code{TRUE}, download and overwrite the existing tar file in localDir +#' and force re-install Spark (in case the local directory or file is corrupted) +#' @return \code{install.spark} returns the local directory where Spark is found or installed +#' @rdname install.spark +#' @name install.spark +#' @aliases install.spark +#' @export +#' @examples +#'\dontrun{ +#' install.spark() +#'} +#' @note install.spark since 2.1.0 +#' @seealso See available Hadoop versions: +#' \href{http://spark.apache.org/downloads.html}{Apache Spark} +install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, + localDir = NULL, overwrite = FALSE) { + version <- paste0("spark-", packageVersion("SparkR")) + hadoopVersion <- tolower(hadoopVersion) + hadoopVersionName <- hadoopVersionName(hadoopVersion) + packageName <- paste(version, "bin", hadoopVersionName, sep = "-") + localDir <- ifelse(is.null(localDir), sparkCachePath(), + normalizePath(localDir, mustWork = FALSE)) + + if (is.na(file.info(localDir)$isdir)) { + dir.create(localDir, recursive = TRUE) + } + + packageLocalDir <- file.path(localDir, packageName) + + if (overwrite) { + message(paste0("Overwrite = TRUE: download and overwrite the tar file", + "and Spark package directory if they exist.")) + } + + # can use dir.exists(packageLocalDir) under R 3.2.0 or later + if (!is.na(file.info(packageLocalDir)$isdir) && !overwrite) { + fmt <- "%s for Hadoop %s found, with SPARK_HOME set to %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), + packageLocalDir) + message(msg) + Sys.setenv(SPARK_HOME = packageLocalDir) + return(invisible(packageLocalDir)) + } else { + message("Spark not found in the cache directory. Installation will start.") + } + + packageLocalPath <- paste0(packageLocalDir, ".tgz") + tarExists <- file.exists(packageLocalPath) + + if (tarExists && !overwrite) { + message("tar file found.") + } else { + robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + } + + message(sprintf("Installing to %s", localDir)) + untar(tarfile = packageLocalPath, exdir = localDir) + if (!tarExists || overwrite) { + unlink(packageLocalPath) + } + message("DONE.") + Sys.setenv(SPARK_HOME = packageLocalDir) + message(paste("SPARK_HOME set to", packageLocalDir)) + invisible(packageLocalDir) +} + +robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { + # step 1: use user-provided url + if (!is.null(mirrorUrl)) { + msg <- sprintf("Use user-provided mirror site: %s.", mirrorUrl) + message(msg) + success <- directDownloadTar(mirrorUrl, version, hadoopVersion, + packageName, packageLocalPath) + if (success) { + return() + } else { + message(paste0("Unable to download from mirrorUrl: ", mirrorUrl)) + } + } else { + message("MirrorUrl not provided.") + } + + # step 2: use url suggested from apache website + message("Looking for preferred site from apache website...") + mirrorUrl <- getPreferredMirror(version, packageName) + if (!is.null(mirrorUrl)) { + success <- directDownloadTar(mirrorUrl, version, hadoopVersion, + packageName, packageLocalPath) + if (success) return() + } else { + message("Unable to find preferred mirror site.") + } + + # step 3: use backup option + message("To use backup site...") + mirrorUrl <- defaultMirrorUrl() + success <- directDownloadTar(mirrorUrl, version, hadoopVersion, + packageName, packageLocalPath) + if (success) { + return(packageLocalPath) + } else { + msg <- sprintf(paste("Unable to download Spark %s for Hadoop %s.", + "Please check network connection, Hadoop version,", + "or provide other mirror sites."), + version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion)) + stop(msg) + } +} + +getPreferredMirror <- function(version, packageName) { + jsonUrl <- paste0("http://www.apache.org/dyn/closer.cgi?path=", + file.path("spark", version, packageName), + ".tgz&as_json=1") + textLines <- readLines(jsonUrl, warn = FALSE) + rowNum <- grep("\"preferred\"", textLines) + linePreferred <- textLines[rowNum] + matchInfo <- regexpr("\"[A-Za-z][A-Za-z0-9+-.]*://.+\"", linePreferred) + if (matchInfo != -1) { + startPos <- matchInfo + 1 + endPos <- matchInfo + attr(matchInfo, "match.length") - 2 + mirrorPreferred <- base::substr(linePreferred, startPos, endPos) + mirrorPreferred <- paste0(mirrorPreferred, "spark") + message(sprintf("Preferred mirror site found: %s", mirrorPreferred)) + } else { + mirrorPreferred <- NULL + } + mirrorPreferred +} + +directDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { + packageRemotePath <- paste0( + file.path(mirrorUrl, version, packageName), ".tgz") + fmt <- "Downloading %s for Hadoop %s from:\n- %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), + packageRemotePath) + message(msg) + + isFail <- tryCatch(download.file(packageRemotePath, packageLocalPath), + error = function(e) { + message(sprintf("Fetch failed from %s", mirrorUrl)) + print(e) + TRUE + }) + !isFail +} + +defaultMirrorUrl <- function() { + "http://www-us.apache.org/dist/spark" +} + +hadoopVersionName <- function(hadoopVersion) { + if (hadoopVersion == "without") { + "without-hadoop" + } else if (grepl("^[0-9]+\\.[0-9]+$", hadoopVersion, perl = TRUE)) { + paste0("hadoop", hadoopVersion) + } else { + hadoopVersion + } +} + +# The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and +# adapt to Spark context +sparkCachePath <- function() { + if (.Platform$OS.type == "windows") { + winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) + if (is.na(winAppPath)) { + msg <- paste("%LOCALAPPDATA% not found.", + "Please define the environment variable", + "or restart and enter an installation path in localDir.") + stop(msg) + } else { + path <- file.path(winAppPath, "spark", "spark", "Cache") + } + } else if (.Platform$OS.type == "unix") { + if (Sys.info()["sysname"] == "Darwin") { + path <- file.path(Sys.getenv("HOME"), "Library/Caches", "spark") + } else { + path <- file.path( + Sys.getenv("XDG_CACHE_HOME", file.path(Sys.getenv("HOME"), ".cache")), "spark") + } + } else { + stop(sprintf("Unknown OS: %s", .Platform$OS.type)) + } + normalizePath(path, mustWork = FALSE) +} + + +installInstruction <- function(mode) { + if (mode == "remote") { + paste0("Connecting to a remote Spark master. ", + "Please make sure Spark package is also installed in this machine.\n", + "- If there is one, set the path in sparkHome parameter or ", + "environment variable SPARK_HOME.\n", + "- If not, you may run install.spark function to do the job. ", + "Please make sure the Spark and the Hadoop versions ", + "match the versions on the cluster. ", + "SparkR package is compatible with Spark ", packageVersion("SparkR"), ".", + "If you need further help, ", + "contact the administrators of the cluster.") + } else { + stop(paste0("No instruction found for ", mode, " mode.")) + } +} diff --git a/R/pkg/R/jvm.R b/R/pkg/R/jvm.R new file mode 100644 index 000000000000..bb5c77544a3d --- /dev/null +++ b/R/pkg/R/jvm.R @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to directly access the JVM running the SparkR backend. + +#' Call Java Methods +#' +#' Call a Java method in the JVM running the Spark driver. The return +#' values are automatically converted to R objects for simple objects. Other +#' values are returned as "jobj" which are references to objects on JVM. +#' +#' @details +#' This is a low level function to access the JVM directly and should only be used +#' for advanced use cases. The arguments and return values that are primitive R +#' types (like integer, numeric, character, lists) are automatically translated to/from +#' Java types (like Integer, Double, String, Array). A full list can be found in +#' serialize.R and deserialize.R in the Apache Spark code base. +#' +#' @param x object to invoke the method on. Should be a "jobj" created by newJObject. +#' @param methodName method name to call. +#' @param ... parameters to pass to the Java method. +#' @return the return value of the Java method. Either returned as a R object +#' if it can be deserialized or returned as a "jobj". See details section for more. +#' @export +#' @seealso \link{sparkR.callJStatic}, \link{sparkR.newJObject} +#' @rdname sparkR.callJMethod +#' @examples +#' \dontrun{ +#' sparkR.session() # Need to have a Spark JVM running before calling newJObject +#' # Create a Java ArrayList and populate it +#' jarray <- sparkR.newJObject("java.util.ArrayList") +#' sparkR.callJMethod(jarray, "add", 42L) +#' sparkR.callJMethod(jarray, "get", 0L) # Will print 42 +#' } +#' @note sparkR.callJMethod since 2.0.1 +sparkR.callJMethod <- function(x, methodName, ...) { + callJMethod(x, methodName, ...) +} + +#' Call Static Java Methods +#' +#' Call a static method in the JVM running the Spark driver. The return +#' value is automatically converted to R objects for simple objects. Other +#' values are returned as "jobj" which are references to objects on JVM. +#' +#' @details +#' This is a low level function to access the JVM directly and should only be used +#' for advanced use cases. The arguments and return values that are primitive R +#' types (like integer, numeric, character, lists) are automatically translated to/from +#' Java types (like Integer, Double, String, Array). A full list can be found in +#' serialize.R and deserialize.R in the Apache Spark code base. +#' +#' @param x fully qualified Java class name that contains the static method to invoke. +#' @param methodName name of static method to invoke. +#' @param ... parameters to pass to the Java method. +#' @return the return value of the Java method. Either returned as a R object +#' if it can be deserialized or returned as a "jobj". See details section for more. +#' @export +#' @seealso \link{sparkR.callJMethod}, \link{sparkR.newJObject} +#' @rdname sparkR.callJStatic +#' @examples +#' \dontrun{ +#' sparkR.session() # Need to have a Spark JVM running before calling callJStatic +#' sparkR.callJStatic("java.lang.System", "currentTimeMillis") +#' sparkR.callJStatic("java.lang.System", "getProperty", "java.home") +#' } +#' @note sparkR.callJStatic since 2.0.1 +sparkR.callJStatic <- function(x, methodName, ...) { + callJStatic(x, methodName, ...) +} + +#' Create Java Objects +#' +#' Create a new Java object in the JVM running the Spark driver. The return +#' value is automatically converted to an R object for simple objects. Other +#' values are returned as a "jobj" which is a reference to an object on JVM. +#' +#' @details +#' This is a low level function to access the JVM directly and should only be used +#' for advanced use cases. The arguments and return values that are primitive R +#' types (like integer, numeric, character, lists) are automatically translated to/from +#' Java types (like Integer, Double, String, Array). A full list can be found in +#' serialize.R and deserialize.R in the Apache Spark code base. +#' +#' @param x fully qualified Java class name. +#' @param ... arguments to be passed to the constructor. +#' @return the object created. Either returned as a R object +#' if it can be deserialized or returned as a "jobj". See details section for more. +#' @export +#' @seealso \link{sparkR.callJMethod}, \link{sparkR.callJStatic} +#' @rdname sparkR.newJObject +#' @examples +#' \dontrun{ +#' sparkR.session() # Need to have a Spark JVM running before calling newJObject +#' # Create a Java ArrayList and populate it +#' jarray <- sparkR.newJObject("java.util.ArrayList") +#' sparkR.callJMethod(jarray, "add", 42L) +#' sparkR.callJMethod(jarray, "get", 0L) # Will print 42 +#' } +#' @note sparkR.newJObject since 2.0.1 +sparkR.newJObject <- function(x, ...) { + newJObject(x, ...) +} diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 50c601fcd9e1..b33a16a7cef9 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -82,15 +82,16 @@ NULL #' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make #' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. #' -#' @param data SparkDataFrame for training. -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param family A description of the error distribution and link function to be used in the model. +#' @param family a description of the error distribution and link function to be used in the model. #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param tol Positive convergence tolerance of iterations. -#' @param maxIter Integer giving the maximal number of IRLS iterations. +#' @param tol positive convergence tolerance of iterations. +#' @param maxIter integer giving the maximal number of IRLS iterations. +#' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model #' @rdname spark.glm @@ -142,15 +143,15 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' Generalized Linear Models (R-compliant) #' #' Fits a generalized linear model, similarly to R's glm(). -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param data SparkDataFrame for training. -#' @param family A description of the error distribution and link function to be used in the model. +#' @param data a SparkDataFrame or R's glm data for training. +#' @param family a description of the error distribution and link function to be used in the model. #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param epsilon Positive convergence tolerance of iterations. -#' @param maxit Integer giving the maximal number of IRLS iterations. +#' @param epsilon positive convergence tolerance of iterations. +#' @param maxit integer giving the maximal number of IRLS iterations. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -171,7 +172,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). -#' @param object A fitted generalized linear model +#' @param object a fitted generalized linear model. #' @return \code{summary} returns a summary object of the fitted model, a list of components #' including at least the coefficients, null/residual deviance, null/residual degrees #' of freedom, AIC and number of iterations IRLS takes. @@ -212,7 +213,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), # Prints the summary of GeneralizedLinearRegressionModel #' @rdname spark.glm -#' @param x Summary object of fitted generalized linear model returned by \code{summary} function +#' @param x summary object of fitted generalized linear model returned by \code{summary} function #' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { @@ -244,7 +245,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { # Makes predictions from a generalized linear model produced by glm() or spark.glm(), # similarly to R's predict(). -#' @param newData SparkDataFrame for testing +#' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named #' "prediction" #' @rdname spark.glm @@ -258,7 +259,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), # Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(), # similarly to R package e1071's predict. -#' @param newData A SparkDataFrame for testing +#' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named #' "prediction" #' @rdname spark.naiveBayes @@ -271,9 +272,9 @@ setMethod("predict", signature(object = "NaiveBayesModel"), # Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} -#' @param object A naive Bayes model fitted by \code{spark.naiveBayes} +#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. #' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and -#' \code{tables}, conditional probabilities given the target label +#' \code{tables}, conditional probabilities given the target label. #' @rdname spark.naiveBayes #' @export #' @note summary(NaiveBayesModel) since 2.0.0 @@ -298,14 +299,15 @@ setMethod("summary", signature(object = "NaiveBayesModel"), #' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make #' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. #' -#' @param data SparkDataFrame for training -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' Note that the response variable of formula is empty in spark.kmeans. -#' @param k Number of centers -#' @param maxIter Maximum iteration number -#' @param initMode The initialization algorithm choosen to fit the model -#' @return \code{spark.kmeans} returns a fitted k-means model +#' @param k number of centers. +#' @param maxIter maximum iteration number. +#' @param initMode the initialization algorithm choosen to fit the model. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.kmeans} returns a fitted k-means model. #' @rdname spark.kmeans #' @aliases spark.kmeans,SparkDataFrame,formula-method #' @name spark.kmeans @@ -346,8 +348,11 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula" #' Get fitted result from a k-means model, similarly to R's fitted(). #' Note: A saved-loaded model does not support this method. #' -#' @param object A fitted k-means model -#' @return \code{fitted} returns a SparkDataFrame containing fitted values +#' @param object a fitted k-means model. +#' @param method type of fitted results, \code{"centers"} for cluster centers +#' or \code{"classes"} for assigned classes. +#' @param ... additional argument(s) passed to the method. +#' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname fitted #' @export #' @examples @@ -371,8 +376,8 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model -#' @param object A fitted k-means model -#' @return \code{summary} returns the model's coefficients, size and cluster +#' @param object a fitted k-means model. +#' @return \code{summary} returns the model's coefficients, size and cluster. #' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -398,7 +403,8 @@ setMethod("summary", signature(object = "KMeansModel"), # Predicted values based on a k-means model -#' @return \code{predict} returns the predicted values based on a k-means model +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns the predicted values based on a k-means model. #' @rdname spark.kmeans #' @export #' @note predict(KMeansModel) since 2.0.0 @@ -414,22 +420,24 @@ setMethod("predict", signature(object = "KMeansModel"), #' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. #' Only categorical data is supported. #' -#' @param data A \code{SparkDataFrame} of observations and labels for model fitting -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param smoothing Smoothing parameter -#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model +#' @param smoothing smoothing parameter. +#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. +#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes #' @aliases spark.naiveBayes,SparkDataFrame,formula-method #' @name spark.naiveBayes -#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/} +#' @seealso e1071: \url{https://cran.r-project.org/package=e1071} #' @export #' @examples #' \dontrun{ -#' df <- createDataFrame(infert) +#' data <- as.data.frame(UCBAdmissions) +#' df <- createDataFrame(data) #' #' # fit a Bernoulli naive Bayes model -#' model <- spark.naiveBayes(df, education ~ ., smoothing = 0) +#' model <- spark.naiveBayes(df, Admit ~ Gender + Dept, smoothing = 0) #' #' # get the summary of the model #' summary(model) @@ -454,8 +462,8 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form # Saves the Bernoulli naive Bayes model to the input path. -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' @param path the directory where the model is saved +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' #' @rdname spark.naiveBayes @@ -473,10 +481,9 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), # Saves the AFT survival regression model to the input path. -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. -#' #' @rdname spark.survreg #' @export #' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 @@ -492,8 +499,8 @@ setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "c # Saves the generalized linear model to the input path. -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' #' @rdname spark.glm @@ -510,8 +517,8 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat # Save fitted MLlib model to the input path -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' #' @rdname spark.kmeans @@ -528,8 +535,8 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' Load a fitted MLlib model from the input path. #' -#' @param path Path of the model to read. -#' @return a fitted MLlib model +#' @param path path of the model to read. +#' @return A fitted MLlib model. #' @rdname read.ml #' @name read.ml #' @export @@ -563,13 +570,13 @@ read.ml <- function(path) { #' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' -#' @param data A SparkDataFrame for training -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. -#' Note that operator '.' is not supported currently -#' @return \code{spark.survreg} returns a fitted AFT survival regression model +#' Note that operator '.' is not supported currently. +#' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg -#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/} +#' @seealso survival: \url{https://cran.r-project.org/package=survival} #' @export #' @examples #' \dontrun{ @@ -591,25 +598,24 @@ read.ml <- function(path) { #' } #' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, ...) { + function(data, formula) { formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", "fit", formula, data@sdf) return(new("AFTSurvivalRegressionModel", jobj = jobj)) }) - # Returns a summary of the AFT survival regression model produced by spark.survreg, # similarly to R's summary(). -#' @param object A fitted AFT survival regression model +#' @param object a fitted AFT survival regression model. #' @return \code{summary} returns a list containing the model's coefficients, #' intercept and log(scale) #' @rdname spark.survreg #' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), - function(object, ...) { + function(object) { jobj <- object@jobj features <- callJMethod(jobj, "rFeatures") coefficients <- callJMethod(jobj, "rCoefficients") @@ -622,9 +628,9 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), # Makes predictions from an AFT survival regression model or a model produced by # spark.survreg, similarly to R package survival's predict. -#' @param newData A SparkDataFrame for testing +#' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted values -#' on the original scale of the data (mean predicted value at scale = 1.0) +#' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg #' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index d39775cabef8..4dee3245f9b7 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -49,7 +49,7 @@ setMethod("lookup", lapply(filtered, function(i) { i[[2]] }) } valsRDD <- lapplyPartition(x, partitionFunc) - collect(valsRDD) + collectRDD(valsRDD) }) #' Count the number of elements for each key, and return the result to the @@ -85,7 +85,7 @@ setMethod("countByKey", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -#' collect(keys(rdd)) # list(1, 3) +#' collectRDD(keys(rdd)) # list(1, 3) #'} # nolint end #' @rdname keys @@ -108,7 +108,7 @@ setMethod("keys", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -#' collect(values(rdd)) # list(2, 4) +#' collectRDD(values(rdd)) # list(2, 4) #'} # nolint end #' @rdname values @@ -135,7 +135,7 @@ setMethod("values", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' makePairs <- lapply(rdd, function(x) { list(x, x) }) -#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' collectRDD(mapValues(makePairs, function(x) { x * 2) }) #' Output: list(list(1,2), list(2,4), list(3,6), ...) #'} #' @rdname mapValues @@ -162,7 +162,7 @@ setMethod("mapValues", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) -#' collect(flatMapValues(rdd, function(x) { x })) +#' collectRDD(flatMapValues(rdd, function(x) { x })) #' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) #'} #' @rdname flatMapValues @@ -198,13 +198,13 @@ setMethod("flatMapValues", #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) -#' parts <- partitionBy(rdd, 2L) +#' parts <- partitionByRDD(rdd, 2L) #' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) #'} #' @rdname partitionBy #' @aliases partitionBy,RDD,integer-method #' @noRd -setMethod("partitionBy", +setMethod("partitionByRDD", signature(x = "RDD"), function(x, numPartitions, partitionFunc = hashCode) { stopifnot(is.numeric(numPartitions)) @@ -261,7 +261,7 @@ setMethod("partitionBy", #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) #' parts <- groupByKey(rdd, 2L) -#' grouped <- collect(parts) +#' grouped <- collectRDD(parts) #' grouped[[1]] # Should be a list(1, list(2, 4)) #'} #' @rdname groupByKey @@ -270,7 +270,7 @@ setMethod("partitionBy", setMethod("groupByKey", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { - shuffled <- partitionBy(x, numPartitions) + shuffled <- partitionByRDD(x, numPartitions) groupVals <- function(part) { vals <- new.env() keys <- new.env() @@ -321,7 +321,7 @@ setMethod("groupByKey", #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) #' parts <- reduceByKey(rdd, "+", 2L) -#' reduced <- collect(parts) +#' reduced <- collectRDD(parts) #' reduced[[1]] # Should be a list(1, 6) #'} #' @rdname reduceByKey @@ -342,7 +342,7 @@ setMethod("reduceByKey", convertEnvsToList(keys, vals) } locallyReduced <- lapplyPartition(x, reduceVals) - shuffled <- partitionBy(locallyReduced, numToInt(numPartitions)) + shuffled <- partitionByRDD(locallyReduced, numToInt(numPartitions)) lapplyPartition(shuffled, reduceVals) }) @@ -430,7 +430,7 @@ setMethod("reduceByKeyLocally", #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) #' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) -#' combined <- collect(parts) +#' combined <- collectRDD(parts) #' combined[[1]] # Should be a list(1, 6) #'} # nolint end @@ -453,7 +453,7 @@ setMethod("combineByKey", convertEnvsToList(keys, combiners) } locallyCombined <- lapplyPartition(x, combineLocally) - shuffled <- partitionBy(locallyCombined, numToInt(numPartitions)) + shuffled <- partitionByRDD(locallyCombined, numToInt(numPartitions)) mergeAfterShuffle <- function(part) { combiners <- new.env() keys <- new.env() @@ -563,13 +563,13 @@ setMethod("foldByKey", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#' joinRDD(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) #'} # nolint end #' @rdname join-methods #' @aliases join,RDD,RDD-method #' @noRd -setMethod("join", +setMethod("joinRDD", signature(x = "RDD", y = "RDD"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) @@ -772,7 +772,7 @@ setMethod("cogroup", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) -#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#' collectRDD(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) #'} # nolint end #' @rdname sortByKey @@ -784,12 +784,12 @@ setMethod("sortByKey", rangeBounds <- list() if (numPartitions > 1) { - rddSize <- count(x) + rddSize <- countRDD(x) # constant from Spark's RangePartitioner maxSampleSize <- numPartitions * 20 fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) - samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + samples <- collectRDD(keys(sampleRDD(x, FALSE, fraction, 1L))) # Note: the built-in R sort() function only works on atomic vectors samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) @@ -822,7 +822,7 @@ setMethod("sortByKey", sortKeyValueList(part, decreasing = !ascending) } - newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + newRDD <- partitionByRDD(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) @@ -841,7 +841,7 @@ setMethod("sortByKey", #' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), #' list("b", 5), list("a", 2))) #' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) -#' collect(subtractByKey(rdd1, rdd2)) +#' collectRDD(subtractByKey(rdd1, rdd2)) #' # list(list("b", 4), list("b", 5)) #'} # nolint end @@ -917,19 +917,19 @@ setMethod("sampleByKey", len <- 0 # mixing because the initial seeds are close to each other - runif(10) + stats::runif(10) for (elem in part) { if (elem[[1]] %in% names(fractions)) { frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))]) if (withReplacement) { - count <- rpois(1, frac) + count <- stats::rpois(1, frac) if (count > 0) { res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { - if (runif(1) < frac) { + if (stats::runif(1) < frac) { len <- len + 1 res[[len]] <- elem } diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index b429f5de13b8..cb5bdb90175b 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -92,8 +92,9 @@ print.structType <- function(x, ...) { #' #' Create a structField object that contains the metadata for a single field in a schema. #' -#' @param x The name of the field -#' @return a structField object +#' @param x the name of the field. +#' @param ... additional argument(s) passed to the method. +#' @return A structField object. #' @rdname structField #' @export #' @examples diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 524f7c4a26b6..cc6d591bb2f4 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -100,7 +100,7 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors #' @param sparkJars Character vector of jar files to pass to the worker nodes -#' @param sparkPackages Character vector of packages from spark-packages.org +#' @param sparkPackages Character vector of package coordinates #' @seealso \link{sparkR.session} #' @rdname sparkR.init-deprecated #' @export @@ -314,20 +314,23 @@ sparkRHive.init <- function(jsc = NULL) { #' Get the existing SparkSession or initialize a new SparkSession. #' -#' Additional Spark properties can be set (...), and these named parameters take priority over -#' over values in master, appName, named lists of sparkConfig. +#' SparkSession is the entry point into SparkR. \code{sparkR.session} gets the existing +#' SparkSession or initializes a new SparkSession. +#' Additional Spark properties can be set in \code{...}, and these named parameters take priority +#' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. #' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. #' -#' @param master The Spark master URL -#' @param appName Application name to register with cluster manager -#' @param sparkHome Spark Home directory -#' @param sparkConfig Named list of Spark configuration to set on worker nodes -#' @param sparkJars Character vector of jar files to pass to the worker nodes -#' @param sparkPackages Character vector of packages from spark-packages.org -#' @param enableHiveSupport Enable support for Hive, fallback if not built with Hive support; once +#' @param master the Spark master URL. +#' @param appName application name to register with cluster manager. +#' @param sparkHome Spark Home directory. +#' @param sparkConfig named list of Spark configuration to set on worker nodes. +#' @param sparkJars character vector of jar files to pass to the worker nodes. +#' @param sparkPackages character vector of package coordinates +#' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once #' set, this cannot be turned off on an existing session +#' @param ... named Spark properties passed to the method. #' @export #' @examples #'\dontrun{ @@ -367,6 +370,8 @@ sparkR.session <- function( } if (!exists(".sparkRjsc", envir = .sparkREnv)) { + retHome <- sparkCheckInstall(sparkHome, master) + if (!is.null(retHome)) sparkHome <- retHome sparkExecutorEnvMap <- new.env() sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, sparkJars, sparkPackages) @@ -396,9 +401,9 @@ sparkR.session <- function( #' Assigns a group ID to all the jobs started by this thread until the group ID is set to a #' different value or cleared. #' -#' @param groupid the ID to be assigned to job groups -#' @param description description for the job group ID -#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation +#' @param groupId the ID to be assigned to job groups. +#' @param description description for the job group ID. +#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation. #' @rdname setJobGroup #' @name setJobGroup #' @examples @@ -486,6 +491,10 @@ sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" sparkConfToSubmitOps[["spark.driver.extraJavaOptions"]] <- "--driver-java-options" sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-path" +sparkConfToSubmitOps[["spark.master"]] <- "--master" +sparkConfToSubmitOps[["spark.yarn.keytab"]] <- "--keytab" +sparkConfToSubmitOps[["spark.yarn.principal"]] <- "--principal" + # Utility function that returns Spark Submit arguments as a string # @@ -529,3 +538,35 @@ processSparkPackages <- function(packages) { } splittedPackages } + +# Utility function that checks and install Spark to local folder if not found +# +# Installation will not be triggered if it's called from sparkR shell +# or if the master url is not local +# +# @param sparkHome directory to find Spark package. +# @param master the Spark master URL, used to check local or remote mode. +# @return NULL if no need to update sparkHome, and new sparkHome otherwise. +sparkCheckInstall <- function(sparkHome, master) { + if (!isSparkRShell()) { + if (!is.na(file.info(sparkHome)$isdir)) { + msg <- paste0("Spark package found in SPARK_HOME: ", sparkHome) + message(msg) + NULL + } else { + if (!nzchar(master) || isMasterLocal(master)) { + msg <- paste0("Spark not found in SPARK_HOME: ", + sparkHome) + message(msg) + packageLocalDir <- install.spark() + packageLocalDir + } else { + msg <- paste0("Spark not found in SPARK_HOME: ", + sparkHome, "\n", installInstruction("remote")) + stop(msg) + } + } + } else { + NULL + } +} diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 2b4ce195cbdd..dcd7198f41ea 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -25,12 +25,13 @@ setOldClass("jobj") #' table. The number of distinct values for each column should be less than 1e4. At most 1e6 #' non-zero pair frequencies will be returned. #' +#' @param x a SparkDataFrame #' @param col1 name of the first column. Distinct items will make the first item of each row. #' @param col2 name of the second column. Distinct items will make the column names of the output. #' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of `col1` and the column names will be the distinct values -#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no -#' occurrences will have zero as their counts. +#' will be the distinct values of \code{col1} and the column names will be the distinct values +#' of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". Pairs +#' that have no occurrences will have zero as their counts. #' #' @rdname crosstab #' @name crosstab @@ -53,10 +54,9 @@ setMethod("crosstab", #' Calculate the sample covariance of two numerical columns of a SparkDataFrame. #' -#' @param x A SparkDataFrame -#' @param col1 the name of the first column -#' @param col2 the name of the second column -#' @return the covariance of the two columns. +#' @param colName1 the name of the first column +#' @param colName2 the name of the second column +#' @return The covariance of the two columns. #' #' @rdname cov #' @name cov @@ -71,19 +71,18 @@ setMethod("crosstab", #' @note cov since 1.6.0 setMethod("cov", signature(x = "SparkDataFrame"), - function(x, col1, col2) { - stopifnot(class(col1) == "character" && class(col2) == "character") + function(x, colName1, colName2) { + stopifnot(class(colName1) == "character" && class(colName2) == "character") statFunctions <- callJMethod(x@sdf, "stat") - callJMethod(statFunctions, "cov", col1, col2) + callJMethod(statFunctions, "cov", colName1, colName2) }) #' Calculates the correlation of two columns of a SparkDataFrame. #' Currently only supports the Pearson Correlation Coefficient. #' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. #' -#' @param x A SparkDataFrame -#' @param col1 the name of the first column -#' @param col2 the name of the second column +#' @param colName1 the name of the first column +#' @param colName2 the name of the second column #' @param method Optional. A character specifying the method for calculating the correlation. #' only "pearson" is allowed now. #' @return The Pearson Correlation Coefficient as a Double. @@ -102,10 +101,10 @@ setMethod("cov", #' @note corr since 1.6.0 setMethod("corr", signature(x = "SparkDataFrame"), - function(x, col1, col2, method = "pearson") { - stopifnot(class(col1) == "character" && class(col2) == "character") + function(x, colName1, colName2, method = "pearson") { + stopifnot(class(colName1) == "character" && class(colName2) == "character") statFunctions <- callJMethod(x@sdf, "stat") - callJMethod(statFunctions, "corr", col1, col2, method) + callJMethod(statFunctions, "corr", colName1, colName2, method) }) @@ -117,7 +116,7 @@ setMethod("corr", #' #' @param x A SparkDataFrame. #' @param cols A vector column names to search frequent items in. -#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' @param support (Optional) The minimum frequency for an item to be considered \code{frequent}. #' Should be greater than 1e-4. Default support = 0.01. #' @return a local R data.frame with the frequent items in each column #' @@ -143,9 +142,9 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' #' Calculates the approximate quantiles of a numerical column of a SparkDataFrame. #' The result of this algorithm has the following deterministic bound: -#' If the SparkDataFrame has N elements and if we request the quantile at probability `p` up to -#' error `err`, then the algorithm will return a sample `x` from the SparkDataFrame so that the -#' *exact* rank of `x` is close to (p * N). More precisely, +#' If the SparkDataFrame has N elements and if we request the quantile at probability p up to +#' error err, then the algorithm will return a sample x from the SparkDataFrame so that the +#' *exact* rank of x is close to (p * N). More precisely, #' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed #' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index ad048b1cd179..abca703617c7 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -67,3 +67,19 @@ rToSQLTypes <- as.environment(list( "double" = "double", "character" = "string", "logical" = "boolean")) + +# Helper function of coverting decimal type. When backend returns column type in the +# format of decimal(,) (e.g., decimal(10, 0)), this function coverts the column type +# as double type. This function converts backend returned types that are not the key +# of PRIMITIVE_TYPES, but should be treated as PRIMITIVE_TYPES. +# @param A type returned from the JVM backend. +# @return A type is the key of the PRIMITIVE_TYPES. +specialtypeshandle <- function(type) { + returntype <- NULL + m <- regexec("^decimal(.+)$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + returntype <- "double" + } + returntype +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 240b9f669bdd..248c57532b6c 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -689,3 +689,26 @@ getSparkContext <- function() { sc <- get(".sparkRjsc", envir = .sparkREnv) sc } + +isMasterLocal <- function(master) { + grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE) +} + +isSparkRShell <- function() { + grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) +} + +# rbind a list of rows with raw (binary) columns +# +# @param inputData a list of rows, with each row a list +# @return data.frame with raw columns as lists +rbindRaws <- function(inputData){ + row1 <- inputData[[1]] + rawcolumns <- ("raw" == sapply(row1, class)) + + listmatrix <- do.call(rbind, inputData) + # A dataframe with all list columns + out <- as.data.frame(listmatrix) + out[!rawcolumns] <- lapply(out[!rawcolumns], unlist) + out +} diff --git a/R/pkg/R/window.R b/R/pkg/R/window.R index 215d0e7b5cfb..0799d841e5dc 100644 --- a/R/pkg/R/window.R +++ b/R/pkg/R/window.R @@ -21,9 +21,9 @@ #' #' Creates a WindowSpec with the partitioning defined. #' -#' @param col A column name or Column by which rows are partitioned to +#' @param col A column name or Column by which rows are partitioned to #' windows. -#' @param ... Optional column names or Columns in addition to col, by +#' @param ... Optional column names or Columns in addition to col, by #' which rows are partitioned to windows. #' #' @rdname windowPartitionBy @@ -32,10 +32,10 @@ #' @export #' @examples #' \dontrun{ -#' ws <- windowPartitionBy("key1", "key2") +#' ws <- orderBy(windowPartitionBy("key1", "key2"), "key3") #' df1 <- select(df, over(lead("value", 1), ws)) #' -#' ws <- windowPartitionBy(df$key1, df$key2) +#' ws <- orderBy(windowPartitionBy(df$key1, df$key2), df$key3) #' df1 <- select(df, over(lead("value", 1), ws)) #' } #' @note windowPartitionBy(character) since 2.0.0 @@ -70,9 +70,9 @@ setMethod("windowPartitionBy", #' #' Creates a WindowSpec with the ordering defined. #' -#' @param col A column name or Column by which rows are ordered within +#' @param col A column name or Column by which rows are ordered within #' windows. -#' @param ... Optional column names or Columns in addition to col, by +#' @param ... Optional column names or Columns in addition to col, by #' which rows are ordered within windows. #' #' @rdname windowOrderBy diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index b69f017de81d..f7a0510711da 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -31,7 +31,7 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { rdd <- textFile(sc, fileName1, 1) saveAsObjectFile(rdd, fileName2) rdd <- objectFile(sc, fileName2) - expect_equal(collect(rdd), as.list(mockFile)) + expect_equal(collectRDD(rdd), as.list(mockFile)) unlink(fileName1) unlink(fileName2, recursive = TRUE) @@ -44,7 +44,7 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { rdd <- parallelize(sc, l, 1) saveAsObjectFile(rdd, fileName) rdd <- objectFile(sc, fileName) - expect_equal(collect(rdd), l) + expect_equal(collectRDD(rdd), l) unlink(fileName, recursive = TRUE) }) @@ -64,7 +64,7 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", saveAsObjectFile(counts, fileName2) counts <- objectFile(sc, fileName2) - output <- collect(counts) + output <- collectRDD(counts) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) @@ -83,7 +83,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_equal(count(rdd), 2) + expect_equal(countRDD(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 6f51d2068727..b780b9458545 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -29,7 +29,7 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { - actual <- collect(unionRDD(rdd, rdd)) + actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -37,13 +37,13 @@ test_that("union on two RDDs", { text.rdd <- textFile(sc, fileName) union.rdd <- unionRDD(rdd, text.rdd) - actual <- collect(union.rdd) + actual <- collectRDD(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) expect_equal(getSerializedMode(union.rdd), "byte") rdd <- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) - actual <- collect(union.rdd) + actual <- collectRDD(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) expect_equal(getSerializedMode(union.rdd), "byte") @@ -54,14 +54,14 @@ test_that("cogroup on two RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) - actual <- collect(cogroup.rdd) + actual <- collectRDD(cogroup.rdd) expect_equal(actual, list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) - actual <- collect(cogroup.rdd) + actual <- collectRDD(cogroup.rdd) expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) expect_equal(sortKeyValueList(actual), @@ -72,7 +72,7 @@ test_that("zipPartitions() on RDDs", { rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 - actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + actual <- collectRDD(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { list(list(x, y, z))} )) expect_equal(actual, list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) @@ -82,19 +82,19 @@ test_that("zipPartitions() on RDDs", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) - actual <- collect(zipPartitions(rdd, rdd, + actual <- collectRDD(zipPartitions(rdd, rdd, func = function(x, y) { list(paste(x, y, sep = "\n")) })) expected <- list(paste(mockFile, mockFile, sep = "\n")) expect_equal(actual, expected) rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipPartitions(rdd1, rdd, + actual <- collectRDD(zipPartitions(rdd1, rdd, func = function(x, y) { list(x + nchar(y)) })) expected <- list(0:1 + nchar(mockFile)) expect_equal(actual, expected) rdd <- map(rdd, function(x) { x }) - actual <- collect(zipPartitions(rdd, rdd1, + actual <- collectRDD(zipPartitions(rdd, rdd1, func = function(x, y) { list(y + nchar(x)) })) expect_equal(actual, expected) diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index cf1d43277105..064249a57aed 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -32,7 +32,7 @@ test_that("using broadcast variable", { useBroadcast <- function(x) { sum(SparkR:::value(randomMatBr) * x) } - actual <- collect(lapply(rrdd, useBroadcast)) + actual <- collectRDD(lapply(rrdd, useBroadcast)) expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) expect_equal(actual, expected) }) @@ -43,7 +43,7 @@ test_that("without using broadcast variable", { useBroadcast <- function(x) { sum(randomMat * x) } - actual <- collect(lapply(rrdd, useBroadcast)) + actual <- collectRDD(lapply(rrdd, useBroadcast)) expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) expect_equal(actual, expected) }) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 2a1bd61b1111..66640c4b0845 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -58,7 +58,7 @@ test_that("repeatedly starting and stopping SparkR", { for (i in 1:4) { sc <- suppressWarnings(sparkR.init()) rdd <- parallelize(sc, 1:20, 2L) - expect_equal(count(rdd), 20) + expect_equal(countRDD(rdd), 20) suppressWarnings(sparkR.stop()) } }) @@ -94,8 +94,9 @@ test_that("rdd GC across sparkR.stop", { rm(rdd2) gc() - count(rdd3) - count(rdd4) + countRDD(rdd3) + countRDD(rdd4) + sparkR.session.stop() }) test_that("job group functions can be called", { diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index d6a3766539c0..025eb9b9fc9d 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -37,7 +37,7 @@ test_that("include inside function", { } data <- lapplyPartition(rdd, generateData) - actual <- collect(data) + actual <- collectRDD(data) } }) @@ -53,6 +53,6 @@ test_that("use include package", { includePackage(sc, plyr) data <- lapplyPartition(rdd, generateData) - actual <- collect(data) + actual <- collectRDD(data) } }) diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/inst/tests/testthat/test_jvm_api.R new file mode 100644 index 000000000000..7348c893d0af --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_jvm_api.R @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("JVM API") + +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("Create and call methods on object", { + jarr <- sparkR.newJObject("java.util.ArrayList") + # Add an element to the array + sparkR.callJMethod(jarr, "add", 1L) + # Check if get returns the same element + expect_equal(sparkR.callJMethod(jarr, "get", 0L), 1L) +}) + +test_that("Call static methods", { + # Convert a boolean to a string + strTrue <- sparkR.callJStatic("java.lang.String", "valueOf", TRUE) + expect_equal(strTrue, "true") +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index f79a8a70aafb..1b230554f7a0 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -67,22 +67,22 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { test_that("collect(), following a parallelize(), gives back the original collections", { numVectorRDD <- parallelize(jsc, numVector, 10) - expect_equal(collect(numVectorRDD), as.list(numVector)) + expect_equal(collectRDD(numVectorRDD), as.list(numVector)) numListRDD <- parallelize(jsc, numList, 1) numListRDD2 <- parallelize(jsc, numList, 4) - expect_equal(collect(numListRDD), as.list(numList)) - expect_equal(collect(numListRDD2), as.list(numList)) + expect_equal(collectRDD(numListRDD), as.list(numList)) + expect_equal(collectRDD(numListRDD2), as.list(numList)) strVectorRDD <- parallelize(jsc, strVector, 2) strVectorRDD2 <- parallelize(jsc, strVector, 3) - expect_equal(collect(strVectorRDD), as.list(strVector)) - expect_equal(collect(strVectorRDD2), as.list(strVector)) + expect_equal(collectRDD(strVectorRDD), as.list(strVector)) + expect_equal(collectRDD(strVectorRDD2), as.list(strVector)) strListRDD <- parallelize(jsc, strList, 4) strListRDD2 <- parallelize(jsc, strList, 1) - expect_equal(collect(strListRDD), as.list(strList)) - expect_equal(collect(strListRDD2), as.list(strList)) + expect_equal(collectRDD(strListRDD), as.list(strList)) + expect_equal(collectRDD(strListRDD2), as.list(strList)) }) test_that("regression: collect() following a parallelize() does not drop elements", { @@ -90,7 +90,7 @@ test_that("regression: collect() following a parallelize() does not drop element collLen <- 10 numPart <- 6 expected <- runif(collLen) - actual <- collect(parallelize(jsc, expected, numPart)) + actual <- collectRDD(parallelize(jsc, expected, numPart)) expect_equal(actual, as.list(expected)) }) @@ -99,12 +99,12 @@ test_that("parallelize() and collect() work for lists of pairs (pairwise data)", numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) numPairsRDDD3 <- parallelize(jsc, numPairs, 3) - expect_equal(collect(numPairsRDDD1), numPairs) - expect_equal(collect(numPairsRDDD2), numPairs) - expect_equal(collect(numPairsRDDD3), numPairs) + expect_equal(collectRDD(numPairsRDDD1), numPairs) + expect_equal(collectRDD(numPairsRDDD2), numPairs) + expect_equal(collectRDD(numPairsRDDD3), numPairs) # can also leave out the parameter name, if the params are supplied in order strPairsRDDD1 <- parallelize(jsc, strPairs, 1) strPairsRDDD2 <- parallelize(jsc, strPairs, 2) - expect_equal(collect(strPairsRDDD1), strPairs) - expect_equal(collect(strPairsRDDD2), strPairs) + expect_equal(collectRDD(strPairsRDDD1), strPairs) + expect_equal(collectRDD(strPairsRDDD2), strPairs) }) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 429311d2924f..d38a763bab8c 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -34,14 +34,14 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_equal(first(rdd), 1) + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_equal(first(newrdd), 2) + expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(count(rdd), 10) - expect_equal(length(rdd), 10) + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { @@ -57,40 +57,40 @@ test_that("count by values and keys", { test_that("lapply on RDD", { multiples <- lapply(rdd, function(x) { 2 * x }) - actual <- collect(multiples) + actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) - actual <- collect(sums) + actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) - actual <- collect(sums) + actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { flat <- flatMap(intRdd, function(x) { list(x, x) }) - actual <- collect(flat) + actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) - actual <- collect(filtered.rdd) + actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) - actual <- collect(filtered.rdd) + actual <- collectRDD(filtered.rdd) expect_equal(actual, list(list(1L, -1))) # Filter out all elements. filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) - actual <- collect(filtered.rdd) + actual <- collectRDD(filtered.rdd) expect_equal(actual, list()) }) @@ -110,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { part <- as.list(unlist(part) * partIndex + i) }) rdd2 <- lapply(rdd2, function(x) x + x) - actual <- collect(rdd2) + actual <- collectRDD(rdd2) expected <- list(24, 24, 24, 24, 24, 168, 170, 172, 174, 176) expect_equal(actual, expected) @@ -126,20 +126,20 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp part <- as.list(unlist(part) * partIndex) }) - cache(rdd2) + cacheRDD(rdd2) expect_true(rdd2@env$isCached) rdd2 <- lapply(rdd2, function(x) x) expect_false(rdd2@env$isCached) - unpersist(rdd2) + unpersistRDD(rdd2) expect_false(rdd2@env$isCached) - persist(rdd2, "MEMORY_AND_DISK") + persistRDD(rdd2, "MEMORY_AND_DISK") expect_true(rdd2@env$isCached) rdd2 <- lapply(rdd2, function(x) x) expect_false(rdd2@env$isCached) - unpersist(rdd2) + unpersistRDD(rdd2) expect_false(rdd2@env$isCached) tempDir <- tempfile(pattern = "checkpoint") @@ -152,7 +152,7 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp expect_false(rdd2@env$isCheckpointed) # make sure the data is collectable - collect(rdd2) + collectRDD(rdd2) unlink(tempDir) }) @@ -169,21 +169,21 @@ test_that("reduce on RDD", { test_that("lapply with dependency", { fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) - actual <- collect(multiples) + actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 5)) }) test_that("lapplyPartitionsWithIndex on RDDs", { func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } - actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } mkTup <- function(partIndex, part) { list(partIndex, part) } - actual <- collect(lapplyPartitionsWithIndex( - partitionBy(pairsRDD, 2L, partitionByParity), + actual <- collectRDD(lapplyPartitionsWithIndex( + partitionByRDD(pairsRDD, 2L, partitionByParity), mkTup), FALSE) expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), @@ -191,7 +191,7 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { - expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { @@ -238,7 +238,7 @@ test_that("takeSample() on RDDs", { test_that("mapValues() on pairwise RDDs", { multiples <- mapValues(intRdd, function(x) { x * 2 }) - actual <- collect(multiples) + actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { list(x[[1]], x[[2]] * 2) }) @@ -247,11 +247,11 @@ test_that("mapValues() on pairwise RDDs", { test_that("flatMapValues() on pairwise RDDs", { l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) - actual <- collect(flatMapValues(l, function(x) { x })) + actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) # Generate x to x+1 for every value - actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) + actual <- collectRDD(flatMapValues(intRdd, function(x) { x: (x + 1) })) expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) @@ -273,8 +273,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { test_that("distinct() on RDDs", { nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) - uniques <- distinct(rdd.rep2) - actual <- sort(unlist(collect(uniques))) + uniques <- distinctRDD(rdd.rep2) + actual <- sort(unlist(collectRDD(uniques))) expect_equal(actual, nums) }) @@ -296,7 +296,7 @@ test_that("sumRDD() on RDDs", { test_that("keyBy on RDDs", { func <- function(x) { x * x } keys <- keyBy(rdd, func) - actual <- collect(keys) + actual <- collectRDD(keys) expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) }) @@ -304,12 +304,12 @@ test_that("repartition/coalesce on RDDs", { rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition - r1 <- repartition(rdd, 2) + r1 <- repartitionRDD(rdd, 2) expect_equal(getNumPartitions(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) - r2 <- repartition(rdd, 6) + r2 <- repartitionRDD(rdd, 6) expect_equal(getNumPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) @@ -323,12 +323,12 @@ test_that("repartition/coalesce on RDDs", { test_that("sortBy() on RDDs", { sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) - actual <- collect(sortedRdd) + actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) - actual <- collect(sortedRdd2) + actual <- collectRDD(sortedRdd2) expect_equal(actual, as.list(nums)) }) @@ -380,13 +380,13 @@ test_that("aggregateRDD() on RDDs", { test_that("zipWithUniqueId() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) - actual <- collect(zipWithUniqueId(rdd)) + actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) expect_equal(actual, expected) rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) - actual <- collect(zipWithUniqueId(rdd)) + actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) @@ -394,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", { test_that("zipWithIndex() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) - actual <- collect(zipWithIndex(rdd)) + actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) - actual <- collect(zipWithIndex(rdd)) + actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) @@ -408,35 +408,35 @@ test_that("zipWithIndex() on RDDs", { test_that("glom() on RDD", { rdd <- parallelize(sc, as.list(1:4), 2L) - actual <- collect(glom(rdd)) + actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { keys <- keys(intRdd) - actual <- collect(keys) + actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { values <- values(intRdd) - actual <- collect(values) + actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { - actual <- collect(pipeRDD(rdd, "more")) + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) - actual <- collect(pipeRDD(trailed.rdd, "sort")) + actual <- collectRDD(pipeRDD(trailed.rdd, "sort")) expected <- list("", "1", "2", "3") expect_equal(actual, expected) rev.nums <- 9:0 rev.rdd <- parallelize(sc, rev.nums, 2L) - actual <- collect(pipeRDD(rev.rdd, "sort")) + actual <- collectRDD(pipeRDD(rev.rdd, "sort")) expected <- as.list(as.character(c(5:9, 0:4))) expect_equal(actual, expected) }) @@ -444,7 +444,7 @@ test_that("pipeRDD() on RDDs", { test_that("zipRDD() on RDDs", { rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) - actual <- collect(zipRDD(rdd1, rdd2)) + actual <- collectRDD(zipRDD(rdd1, rdd2)) expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) @@ -453,17 +453,17 @@ test_that("zipRDD() on RDDs", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) - actual <- collect(zipRDD(rdd, rdd)) + actual <- collectRDD(zipRDD(rdd, rdd)) expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipRDD(rdd1, rdd)) + actual <- collectRDD(zipRDD(rdd1, rdd)) expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) expect_equal(actual, expected) rdd1 <- map(rdd, function(x) { x }) - actual <- collect(zipRDD(rdd, rdd1)) + actual <- collectRDD(zipRDD(rdd, rdd1)) expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) @@ -472,7 +472,7 @@ test_that("zipRDD() on RDDs", { test_that("cartesian() on RDDs", { rdd <- parallelize(sc, 1:3) - actual <- collect(cartesian(rdd, rdd)) + actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), list( list(1, 1), list(1, 2), list(1, 3), @@ -481,7 +481,7 @@ test_that("cartesian() on RDDs", { # test case where one RDD is empty emptyRdd <- parallelize(sc, list()) - actual <- collect(cartesian(rdd, emptyRdd)) + actual <- collectRDD(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) mockFile <- c("Spark is pretty.", "Spark is awesome.") @@ -489,7 +489,7 @@ test_that("cartesian() on RDDs", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - actual <- collect(cartesian(rdd, rdd)) + actual <- collectRDD(cartesian(rdd, rdd)) expected <- list( list("Spark is awesome.", "Spark is pretty."), list("Spark is awesome.", "Spark is awesome."), @@ -498,7 +498,7 @@ test_that("cartesian() on RDDs", { expect_equal(sortKeyValueList(actual), expected) rdd1 <- parallelize(sc, 0:1) - actual <- collect(cartesian(rdd1, rdd)) + actual <- collectRDD(cartesian(rdd1, rdd)) expect_equal(sortKeyValueList(actual), list( list(0, "Spark is pretty."), @@ -507,7 +507,7 @@ test_that("cartesian() on RDDs", { list(1, "Spark is awesome."))) rdd1 <- map(rdd, function(x) { x }) - actual <- collect(cartesian(rdd, rdd1)) + actual <- collectRDD(cartesian(rdd, rdd1)) expect_equal(sortKeyValueList(actual), expected) unlink(fileName) @@ -518,24 +518,24 @@ test_that("subtract() on RDDs", { rdd1 <- parallelize(sc, l) # subtract by itself - actual <- collect(subtract(rdd1, rdd1)) + actual <- collectRDD(subtract(rdd1, rdd1)) expect_equal(actual, list()) # subtract by an empty RDD rdd2 <- parallelize(sc, list()) - actual <- collect(subtract(rdd1, rdd2)) + actual <- collectRDD(subtract(rdd1, rdd2)) expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), l) rdd2 <- parallelize(sc, list(2, 4)) - actual <- collect(subtract(rdd1, rdd2)) + actual <- collectRDD(subtract(rdd1, rdd2)) expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), list(1, 1, 3)) l <- list("a", "a", "b", "b", "c", "d") rdd1 <- parallelize(sc, l) rdd2 <- parallelize(sc, list("b", "d")) - actual <- collect(subtract(rdd1, rdd2)) + actual <- collectRDD(subtract(rdd1, rdd2)) expect_equal(as.list(sort(as.vector(actual, mode = "character"))), list("a", "a", "c")) }) @@ -546,17 +546,17 @@ test_that("subtractByKey() on pairwise RDDs", { rdd1 <- parallelize(sc, l) # subtractByKey by itself - actual <- collect(subtractByKey(rdd1, rdd1)) + actual <- collectRDD(subtractByKey(rdd1, rdd1)) expect_equal(actual, list()) # subtractByKey by an empty RDD rdd2 <- parallelize(sc, list()) - actual <- collect(subtractByKey(rdd1, rdd2)) + actual <- collectRDD(subtractByKey(rdd1, rdd2)) expect_equal(sortKeyValueList(actual), sortKeyValueList(l)) rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) - actual <- collect(subtractByKey(rdd1, rdd2)) + actual <- collectRDD(subtractByKey(rdd1, rdd2)) expect_equal(actual, list(list("b", 4), list("b", 5))) @@ -564,76 +564,76 @@ test_that("subtractByKey() on pairwise RDDs", { list(2, 5), list(1, 2)) rdd1 <- parallelize(sc, l) rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1))) - actual <- collect(subtractByKey(rdd1, rdd2)) + actual <- collectRDD(subtractByKey(rdd1, rdd2)) expect_equal(actual, list(list(2, 4), list(2, 5))) }) test_that("intersection() on RDDs", { # intersection with self - actual <- collect(intersection(rdd, rdd)) + actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) # intersection with an empty RDD emptyRdd <- parallelize(sc, list()) - actual <- collect(intersection(rdd, emptyRdd)) + actual <- collectRDD(intersection(rdd, emptyRdd)) expect_equal(actual, list()) rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) - actual <- collect(intersection(rdd1, rdd2)) + actual <- collectRDD(intersection(rdd1, rdd2)) expect_equal(sort(as.integer(actual)), 1:3) }) test_that("join() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(actual, list()) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(actual, list()) }) test_that("leftOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) - actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -642,26 +642,26 @@ test_that("leftOuterJoin() on pairwise RDDs", { test_that("rightOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3))) rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) @@ -669,14 +669,14 @@ test_that("rightOuterJoin() on pairwise RDDs", { test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) - actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) - actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), @@ -684,14 +684,14 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) @@ -700,21 +700,21 @@ test_that("fullOuterJoin() on pairwise RDDs", { test_that("sortByKey() on pairwise RDDs", { numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) - actual <- collect(sortedRdd) + actual <- collectRDD(sortedRdd) numPairs <- lapply(nums, function(x) { list (x, x) }) expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) sortedRdd2 <- sortByKey(numPairsRdd2) - actual <- collect(sortedRdd2) + actual <- collectRDD(sortedRdd2) expect_equal(actual, numPairs) # sort by string keys l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) rdd3 <- parallelize(sc, l, 2L) sortedRdd3 <- sortByKey(rdd3) - actual <- collect(sortedRdd3) + actual <- collectRDD(sortedRdd3) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) # test on the boundary cases @@ -722,27 +722,27 @@ test_that("sortByKey() on pairwise RDDs", { # boundary case 1: the RDD to be sorted has only 1 partition rdd4 <- parallelize(sc, l, 1L) sortedRdd4 <- sortByKey(rdd4) - actual <- collect(sortedRdd4) + actual <- collectRDD(sortedRdd4) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) # boundary case 2: the sorted RDD has only 1 partition rdd5 <- parallelize(sc, l, 2L) sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) - actual <- collect(sortedRdd5) + actual <- collectRDD(sortedRdd5) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) # boundary case 3: the RDD to be sorted has only 1 element l2 <- list(list("a", 1)) rdd6 <- parallelize(sc, l2, 2L) sortedRdd6 <- sortByKey(rdd6) - actual <- collect(sortedRdd6) + actual <- collectRDD(sortedRdd6) expect_equal(actual, l2) # boundary case 4: the RDD to be sorted has 0 element l3 <- list() rdd7 <- parallelize(sc, l3, 2L) sortedRdd7 <- sortByKey(rdd7) - actual <- collect(sortedRdd7) + actual <- collectRDD(sortedRdd7) expect_equal(actual, l3) }) @@ -766,7 +766,7 @@ test_that("collectAsMap() on a pairwise RDD", { test_that("show()", { rdd <- parallelize(sc, list(1:10)) - expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") + expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index 7d4f34201644..07f3b02df664 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -39,7 +39,7 @@ strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { grouped <- groupByKey(intRdd, 2L) - actual <- collect(grouped) + actual <- collectRDD(grouped) expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -48,7 +48,7 @@ test_that("groupByKey for integers", { test_that("groupByKey for doubles", { grouped <- groupByKey(doubleRdd, 2L) - actual <- collect(grouped) + actual <- collectRDD(grouped) expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -57,7 +57,7 @@ test_that("groupByKey for doubles", { test_that("reduceByKey for ints", { reduced <- reduceByKey(intRdd, "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -65,7 +65,7 @@ test_that("reduceByKey for ints", { test_that("reduceByKey for doubles", { reduced <- reduceByKey(doubleRdd, "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(1.5, 199), list(2.5, 101)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -74,7 +74,7 @@ test_that("reduceByKey for doubles", { test_that("combineByKey for ints", { reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -82,7 +82,7 @@ test_that("combineByKey for ints", { test_that("combineByKey for doubles", { reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(1.5, 199), list(2.5, 101)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -94,7 +94,7 @@ test_that("combineByKey for characters", { list("other", 3L), list("max", 4L)), 2L) reduced <- combineByKey(stringKeyRDD, function(x) { x }, "+", "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list("max", 5L), list("min", 2L), list("other", 3L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -109,7 +109,7 @@ test_that("aggregateByKey", { combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - actual <- collect(aggregatedRDD) + actual <- collectRDD(aggregatedRDD) expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -122,7 +122,7 @@ test_that("aggregateByKey", { combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - actual <- collect(aggregatedRDD) + actual <- collectRDD(aggregatedRDD) expected <- list(list("a", list(3, 2)), list("b", list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -132,7 +132,7 @@ test_that("foldByKey", { # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -140,7 +140,7 @@ test_that("foldByKey", { # test foldByKey for double keys folded <- foldByKey(doubleRdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list(1.5, 199), list(2.5, 101)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -151,7 +151,7 @@ test_that("foldByKey", { stringKeyRDD <- parallelize(sc, stringKeyPairs) folded <- foldByKey(stringKeyRDD, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list("b", 101), list("a", 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -159,14 +159,14 @@ test_that("foldByKey", { # test foldByKey for empty pair RDD rdd <- parallelize(sc, list()) folded <- foldByKey(rdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list() expect_equal(actual, expected) # test foldByKey for RDD with only 1 pair rdd <- parallelize(sc, list(list(1, 1))) folded <- foldByKey(rdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list(1, 1)) expect_equal(actual, expected) }) @@ -175,7 +175,7 @@ test_that("partitionBy() partitions data correctly", { # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } - resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + resultRDD <- partitionByRDD(numPairsRdd, 2L, partitionByMagnitude) expected_first <- list(list(1, 100), list(2, 200)) # key less than 3 expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key greater than or equal 3 @@ -191,7 +191,7 @@ test_that("partitionBy works with dependencies", { partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } # Partition by parity - resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + resultRDD <- partitionByRDD(numPairsRdd, numPartitions = 2L, partitionByParity) # keys even; 100 %% 2 == 0 expected_first <- list(list(2, 200), list(4, -1)) @@ -208,7 +208,7 @@ test_that("test partitionBy with string keys", { words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) - resultRDD <- partitionBy(wordCount, 2L) + resultRDD <- partitionByRDD(wordCount, 2L) expected_first <- list(list("Dexter", 1), list("Dexter", 1)) expected_second <- list(list("and", 1), list("and", 1)) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7e59fdf4620e..cdb8ff6b6f8c 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -39,7 +39,7 @@ setHiveContext <- function(sc) { # initialize once and reuse ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc, FALSE) }, error = function(err) { skip("Hive is not build with SparkSQL, skipped") @@ -208,7 +208,7 @@ test_that("create DataFrame from RDD", { unsetHiveContext() }) -test_that("read csv as DataFrame", { +test_that("read/write csv as DataFrame", { csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -243,7 +243,17 @@ test_that("read csv as DataFrame", { expect_equal(count(withoutna2), 3) expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) + # writing csv file + csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") + write.df(df2, path = csvPath2, "csv", header = "true") + df3 <- read.df(csvPath2, "csv", header = "true") + expect_equal(nrow(df3), nrow(df2)) + expect_equal(colnames(df3), colnames(df2)) + csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) + expect_equal(colnames(df3), colnames(csv)) + unlink(csvPath) + unlink(csvPath2) }) test_that("convert NAs to null type in DataFrames", { @@ -490,7 +500,7 @@ test_that("read/write json files", { test_that("jsonRDD() on a RDD with json string", { sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) - expect_equal(count(rdd), 3) + expect_equal(countRDD(rdd), 3) df <- suppressWarnings(jsonRDD(sqlContext, rdd)) expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) @@ -526,6 +536,17 @@ test_that( expect_is(newdf, "SparkDataFrame") expect_equal(count(newdf), 1) dropTempView("table1") + + createOrReplaceTempView(df, "dfView") + sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1")) + out <- capture.output(sqlCast) + expect_true(is.data.frame(sqlCast)) + expect_equal(names(sqlCast)[1], "x") + expect_equal(nrow(sqlCast), 1) + expect_equal(ncol(sqlCast), 1) + expect_equal(out[1], " x") + expect_equal(out[2], "1 2") + dropTempView("dfView") }) test_that("test cache, uncache and clearCache", { @@ -582,7 +603,7 @@ test_that("toRDD() returns an RRDD", { df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") - expect_equal(count(testRDD), 3) + expect_equal(countRDD(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { @@ -592,7 +613,7 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { unioned <- unionRDD(RDD1, RDD2) expect_is(unioned, "RDD") expect_equal(getSerializedMode(unioned), "byte") - expect_equal(collect(unioned)[[2]]$name, "Andy") + expect_equal(collectRDD(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -614,14 +635,14 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { unionByte <- unionRDD(rdd, dfRDD) expect_is(unionByte, "RDD") expect_equal(getSerializedMode(unionByte), "byte") - expect_equal(collect(unionByte)[[1]], 1) - expect_equal(collect(unionByte)[[12]]$name, "Andy") + expect_equal(collectRDD(unionByte)[[1]], 1) + expect_equal(collectRDD(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) expect_is(unionString, "RDD") expect_equal(getSerializedMode(unionString), "byte") - expect_equal(collect(unionString)[[1]], "Michael") - expect_equal(collect(unionString)[[5]]$name, "Andy") + expect_equal(collectRDD(unionString)[[1]], "Michael") + expect_equal(collectRDD(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { @@ -633,7 +654,7 @@ test_that("objectFile() works with row serialization", { expect_is(objectIn, "RDD") expect_equal(getSerializedMode(objectIn), "byte") - expect_equal(collect(objectIn)[[2]]$age, 30) + expect_equal(collectRDD(objectIn)[[2]]$age, 30) }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { @@ -643,7 +664,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { row }) expect_is(testRDD, "RDD") - collected <- collect(testRDD) + collected <- collectRDD(testRDD) expect_equal(collected[[1]]$name, "Michael") expect_equal(collected[[2]]$newCol, 35) }) @@ -715,10 +736,10 @@ test_that("multiple pipeline transformations result in an RDD with the correct v row }) expect_is(second, "RDD") - expect_equal(count(second), 3) - expect_equal(collect(second)[[2]]$age, 35) - expect_true(collect(second)[[2]]$testCol) - expect_false(collect(second)[[3]]$testCol) + expect_equal(countRDD(second), 3) + expect_equal(collectRDD(second)[[2]]$age, 35) + expect_true(collectRDD(second)[[2]]$testCol) + expect_false(collectRDD(second)[[3]]$testCol) }) test_that("cache(), persist(), and unpersist() on a DataFrame", { @@ -1608,7 +1629,7 @@ test_that("toJSON() returns an RDD of the correct values", { testRDD <- toJSON(df) expect_is(testRDD, "RDD") expect_equal(getSerializedMode(testRDD), "string") - expect_equal(collect(testRDD)[[1]], mockLines[1]) + expect_equal(collectRDD(testRDD)[[1]], mockLines[1]) }) test_that("showDF()", { @@ -2081,6 +2102,9 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { # Test primitive types DF <- createDataFrame(data, schema) expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + createOrReplaceTempView(DF, "DFView") + sqlCast <- sql("select cast('2' as decimal) as x from DFView limit 1") + expect_equal(coltypes(sqlCast), "numeric") # Test complex types x <- createDataFrame(list(list(as.environment( @@ -2124,6 +2148,14 @@ test_that("Method str()", { "setosa\" \"setosa\" \"setosa\" \"setosa\"")) expect_equal(out[7], " $ col : logi TRUE TRUE TRUE TRUE TRUE TRUE") + createOrReplaceTempView(irisDF2, "irisView") + + sqlCast <- sql("select cast('2' as decimal) as x from irisView limit 1") + castStr <- capture.output(str(sqlCast)) + expect_equal(length(castStr), 2) + expect_equal(castStr[1], "'SparkDataFrame': 1 variables:") + expect_equal(castStr[2], " $ x: num 2") + # A random dataset with many columns. This test is to check str limits # the number of columns. Therefore, it will suffice to check for the # number of returned rows @@ -2240,6 +2272,27 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { expect_identical(expected, result) }) +test_that("dapplyCollect() on DataFrame with a binary column", { + + df <- data.frame(key = 1:3) + df$bytes <- lapply(df$key, serialize, connection = NULL) + + df_spark <- createDataFrame(df) + + result1 <- collect(df_spark) + expect_identical(df, result1) + + result2 <- dapplyCollect(df_spark, function(x) x) + expect_identical(df, result2) + + # A data.frame with a single column of bytes + scb <- subset(df, select = "bytes") + scb_spark <- createDataFrame(scb) + result <- dapplyCollect(scb_spark, function(x) x) + expect_identical(scb, result) + +}) + test_that("repartition by columns on DataFrame", { df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), @@ -2477,6 +2530,12 @@ test_that("enableHiveSupport on SparkSession", { expect_equal(value, "hive") }) +test_that("Spark version from SparkSession", { + ver <- callJMethod(sc, "version") + version <- sparkR.version() + expect_equal(ver, version) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index daf5e41abe13..dcf479363b9a 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -36,32 +36,32 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("take() gives back the original elements in correct count and order", { numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition - expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) # case: number of elements to take is the same as the size of the first partition - expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + expect_equal(takeRDD(numVectorRDD, 11), as.list(head(numVector, n = 11))) # case: number of elements to take is greater than all elements - expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) - expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + expect_equal(takeRDD(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(takeRDD(numVectorRDD, length(numVector) + 1), as.list(numVector)) numListRDD <- parallelize(sc, numList, 1) numListRDD2 <- parallelize(sc, numList, 4) - expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) - expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) - expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) - expect_equal(take(numListRDD2, 999), numList) + expect_equal(takeRDD(numListRDD, 3), takeRDD(numListRDD2, 3)) + expect_equal(takeRDD(numListRDD, 5), takeRDD(numListRDD2, 5)) + expect_equal(takeRDD(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(takeRDD(numListRDD2, 999), numList) strVectorRDD <- parallelize(sc, strVector, 2) strVectorRDD2 <- parallelize(sc, strVector, 3) - expect_equal(take(strVectorRDD, 4), as.list(strVector)) - expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + expect_equal(takeRDD(strVectorRDD, 4), as.list(strVector)) + expect_equal(takeRDD(strVectorRDD2, 2), as.list(head(strVector, n = 2))) strListRDD <- parallelize(sc, strList, 4) strListRDD2 <- parallelize(sc, strList, 1) - expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) - expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + expect_equal(takeRDD(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(takeRDD(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_equal(length(take(strListRDD, 0)), 0) - expect_equal(length(take(strVectorRDD, 0)), 0) - expect_equal(length(take(numListRDD, 0)), 0) - expect_equal(length(take(numVectorRDD, 0)), 0) + expect_equal(length(takeRDD(strListRDD, 0)), 0) + expect_equal(length(takeRDD(strVectorRDD, 0)), 0) + expect_equal(length(takeRDD(numListRDD, 0)), 0) + expect_equal(length(takeRDD(numVectorRDD, 0)), 0) }) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 7b2cc74753fe..ba434a5d4127 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -29,8 +29,8 @@ test_that("textFile() on a local file returns an RDD", { rdd <- textFile(sc, fileName) expect_is(rdd, "RDD") - expect_true(count(rdd) > 0) - expect_equal(count(rdd), 2) + expect_true(countRDD(rdd) > 0) + expect_equal(countRDD(rdd), 2) unlink(fileName) }) @@ -40,7 +40,7 @@ test_that("textFile() followed by a collect() returns the same content", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_equal(collect(rdd), as.list(mockFile)) + expect_equal(collectRDD(rdd), as.list(mockFile)) unlink(fileName) }) @@ -55,7 +55,7 @@ test_that("textFile() word count works as expected", { wordCount <- lapply(words, function(word) { list(word, 1L) }) counts <- reduceByKey(wordCount, "+", 2L) - output <- collect(counts) + output <- collectRDD(counts) expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), list("Spark", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) @@ -72,7 +72,7 @@ test_that("several transformations on RDD created by textFile()", { # PipelinedRDD initially created from RDD rdd <- lapply(rdd, function(x) paste(x, x)) } - collect(rdd) + collectRDD(rdd) unlink(fileName) }) @@ -85,7 +85,7 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", rdd <- textFile(sc, fileName1, 1L) saveAsTextFile(rdd, fileName2) rdd <- textFile(sc, fileName2) - expect_equal(collect(rdd), as.list(mockFile)) + expect_equal(collectRDD(rdd), as.list(mockFile)) unlink(fileName1) unlink(fileName2) @@ -97,7 +97,7 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { rdd <- parallelize(sc, l, 1L) saveAsTextFile(rdd, fileName) rdd <- textFile(sc, fileName) - expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + expect_equal(collectRDD(rdd), lapply(l, function(x) {toString(x)})) unlink(fileName) }) @@ -117,7 +117,7 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { saveAsTextFile(counts, fileName2) rdd <- textFile(sc, fileName2) - output <- collect(rdd) + output <- collectRDD(rdd) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expectedStr <- lapply(expected, function(x) { toString(x) }) @@ -134,7 +134,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_equal(count(rdd), 2) + expect_equal(countRDD(rdd), 2) unlink(fileName1) unlink(fileName2) @@ -147,16 +147,16 @@ test_that("Pipelined operations on RDDs created using textFile", { rdd <- textFile(sc, fileName) lengths <- lapply(rdd, function(x) { length(x) }) - expect_equal(collect(lengths), list(1, 1)) + expect_equal(collectRDD(lengths), list(1, 1)) lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) - expect_equal(collect(lengthsPipelined), list(11, 11)) + expect_equal(collectRDD(lengthsPipelined), list(11, 11)) lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) - expect_equal(collect(lengths30), list(31, 31)) + expect_equal(collectRDD(lengths30), list(31, 31)) lengths20 <- lapply(lengths, function(x) { x + 20 }) - expect_equal(collect(lengths20), list(21, 21)) + expect_equal(collectRDD(lengths20), list(21, 21)) unlink(fileName) }) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 21a119a06b93..dd5001b1599a 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -24,7 +24,7 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { # It's hard to manually create a Java List using rJava, since it does not - # support generics well. Instead, we rely on collect() returning a + # support generics well. Instead, we rely on collectRDD() returning a # JList. nums <- as.list(1:10) rdd <- parallelize(sc, nums, 1L) @@ -48,7 +48,7 @@ test_that("serializeToBytes on RDD", { text.rdd <- textFile(sc, fileName) expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) - expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_equal(collectRDD(ser.rdd), as.list(mockFile)) expect_equal(getSerializedMode(ser.rdd), "byte") unlink(fileName) @@ -128,7 +128,7 @@ test_that("cleanClosure on R functions", { env <- environment(newF) expect_equal(ls(env), "t") expect_equal(get("t", envir = env, inherits = FALSE), t) - actual <- collect(lapply(rdd, f)) + actual <- collectRDD(lapply(rdd, f)) expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) expect_equal(actual, expected) @@ -182,3 +182,27 @@ test_that("overrideEnvs", { expect_equal(config[["param_only"]], "blah") expect_equal(config[["config_only"]], "ok") }) + +test_that("rbindRaws", { + + # Mixed Column types + r <- serialize(1:5, connection = NULL) + r1 <- serialize(1, connection = NULL) + r2 <- serialize(letters, connection = NULL) + r3 <- serialize(1:10, connection = NULL) + inputData <- list(list(1L, r1, "a", r), list(2L, r2, "b", r), + list(3L, r3, "c", r)) + expected <- data.frame(V1 = 1:3) + expected$V2 <- list(r1, r2, r3) + expected$V3 <- c("a", "b", "c") + expected$V4 <- list(r, r, r) + result <- rbindRaws(inputData) + expect_equal(expected, result) + + # Single binary column + input <- list(list(r1), list(r2), list(r3)) + expected <- subset(expected, select = "V2") + result <- setNames(rbindRaws(input), "V2") + expect_equal(expected, result) + +}) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index debf0180183a..cfe41ded200c 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -36,7 +36,14 @@ compute <- function(mode, partition, serializer, deserializer, key, # available since R 3.2.4. So we set the global option here. oldOpt <- getOption("stringsAsFactors") options(stringsAsFactors = FALSE) - inputData <- do.call(rbind.data.frame, inputData) + + # Handle binary data types + if ("raw" %in% sapply(inputData[[1]], class)) { + inputData <- SparkR:::rbindRaws(inputData) + } else { + inputData <- do.call(rbind.data.frame, inputData) + } + options(stringsAsFactors = oldOpt) names(inputData) <- colNames diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd new file mode 100644 index 000000000000..5156c9e566c9 --- /dev/null +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -0,0 +1,643 @@ +--- +title: "SparkR - Practical Guide" +output: + html_document: + theme: united + toc: true + toc_depth: 4 + toc_float: true + highlight: textmate +--- + +## Overview + +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). + +## Getting Started + +We begin with an example running on the local machine and provide an overview of the use of SparkR: data ingestion, data processing and machine learning. + +First, let's load and attach the package. +```{r, message=FALSE} +library(SparkR) +``` + +`SparkSession` is the entry point into SparkR which connects your R program to a Spark cluster. You can create a `SparkSession` using `sparkR.session` and pass in options such as the application name, any Spark packages depended on, etc. + +We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). + +```{r, message=FALSE} +sparkR.session() +``` + +The operations in SparkR are centered around an R class called `SparkDataFrame`. It is a distributed collection of data organized into named columns, which is conceptually equivalent to a table in a relational database or a data frame in R, but with richer optimizations under the hood. + +`SparkDataFrame` can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing local R data frames. For example, we create a `SparkDataFrame` from a local R data frame, + +```{r} +cars <- cbind(model = rownames(mtcars), mtcars) +carsDF <- createDataFrame(cars) +``` + +We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` function. +```{r} +head(carsDF) +``` + +Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +```{r} +carsSubDF <- select(carsDF, "model", "mpg", "hp") +carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) +head(carsSubDF) +``` + +SparkR can use many common aggregation functions after grouping. + +```{r} +carsGPDF <- summarize(groupBy(carsDF, carsDF$gear), count = n(carsDF$gear)) +head(carsGPDF) +``` + +The results `carsDF` and `carsSubDF` are `SparkDataFrame` objects. To convert back to R `data.frame`, we can use `collect`. **Caution**: This can cause your interactive environment to run out of memory, though, because `collect()` fetches the entire distributed `DataFrame` to your client, which is acting as a Spark driver. +```{r} +carsGP <- collect(carsGPDF) +class(carsGP) +``` + +SparkR supports a number of commonly used machine learning algorithms. Under the hood, SparkR uses MLlib to train the model. Users can call `summary` to print a summary of the fitted model, `predict` to make predictions on new data, and `write.ml`/`read.ml` to save/load fitted models. + +SparkR supports a subset of R formula operators for model fitting, including ‘~’, ‘.’, ‘:’, ‘+’, and ‘-‘. We use linear regression as an example. +```{r} +model <- spark.glm(carsDF, mpg ~ wt + cyl) +``` + +The result matches that returned by R `glm` function applied to the corresponding `data.frame` `mtcars` of `carsDF`. In fact, for Generalized Linear Model, we specifically expose `glm` for `SparkDataFrame` as well so that the above is equivalent to `model <- glm(mpg ~ wt + cyl, data = carsDF)`. + +```{r} +summary(model) +``` + +The model can be saved by `write.ml` and loaded back using `read.ml`. +```{r, eval=FALSE} +write.ml(model, path = "/HOME/tmp/mlModel/glmModel") +``` + +In the end, we can stop Spark Session by running +```{r, eval=FALSE} +sparkR.session.stop() +``` + +## Setup + +### Installation + +Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. + +If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). Alternatively, we provide an easy-to-use function `install.spark` to complete this process. You don't have to call it explicitly. We will check the installation when `sparkR.session` is called and `install.spark` function will be triggered automatically if no installation is found. + +```{r, eval=FALSE} +install.spark() +``` + +If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the Spark installation is. + +```{r, eval=FALSE} +sparkR.session(sparkHome = "/HOME/spark") +``` + +### Spark Session {#SetupSparkSession} + + +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). + +In particular, the following Spark driver properties can be set in `sparkConfig`. + +Property Name | Property group | spark-submit equivalent +---------------- | ------------------ | ---------------------- +spark.driver.memory | Application Properties | --driver-memory +spark.driver.extraClassPath | Runtime Environment | --driver-class-path +spark.driver.extraJavaOptions | Runtime Environment | --driver-java-options +spark.driver.extraLibraryPath | Runtime Environment | --driver-library-path + +**For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`. + +```{r, eval=FALSE} +spark_warehouse_path <- file.path(path.expand('~'), "spark-warehouse") +sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) +``` + + +#### Cluster Mode +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. + +When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with +```{r, echo=FALSE, tidy = TRUE} +paste("Spark", packageVersion("SparkR")) +``` +It should be used both on the local computer and on the remote cluster. + +To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +For example, to connect to a local standalone Spark master, we can call + +```{r, eval=FALSE} +sparkR.session(master = "spark://local:7077") +``` + +For YARN cluster, SparkR supports the client mode with the master set as "yarn". +```{r, eval=FALSE} +sparkR.session(master = "yarn") +``` +Yarn cluster mode is not supported in the current version. + +## Data Import + +### Local Data Frame +The simplest way is to convert a local R data frame into a `SparkDataFrame`. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a `SparkDataFrame`. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +```{r} +df <- as.DataFrame(faithful) +head(df) +``` + +### Data Sources +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session'.` + +```{r, eval=FALSE} +sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") +``` + +We can see how to use data sources using an example CSV input file. For more information please refer to SparkR [read.df](https://spark.apache.org/docs/latest/api/R/read.df.html) API documentation. +```{r, eval=FALSE} +df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "NA") +``` + +The data sources API natively supports JSON formatted input files. Note that the file that is used here is not a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +Let's take a look at the first two lines of the raw JSON file used here. + +```{r} +filePath <- paste0(sparkR.conf("spark.home"), + "/examples/src/main/resources/people.json") +readLines(filePath, n = 2L) +``` + +We use `read.df` to read that into a `SparkDataFrame`. + +```{r} +people <- read.df(filePath, "json") +count(people) +head(people) +``` + +SparkR automatically infers the schema from the JSON file. +```{r} +printSchema(people) +``` + +If we want to read multiple JSON files, `read.json` can be used. +```{r} +people <- read.json(paste0(Sys.getenv("SPARK_HOME"), + c("/examples/src/main/resources/people.json", + "/examples/src/main/resources/people.json"))) +count(people) +``` + +The data sources API can also be used to save out `SparkDataFrames` into multiple file formats. For example we can save the `SparkDataFrame` from the previous example to a Parquet file using `write.df`. +```{r, eval=FALSE} +write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite") +``` + +### Hive Tables +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). + +```{r, eval=FALSE} +sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + +txtPath <- paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/kv1.txt") +sqlCMD <- sprintf("LOAD DATA LOCAL INPATH '%s' INTO TABLE src", txtPath) +sql(sqlCMD) + +results <- sql("FROM src SELECT key, value") + +# results is now a SparkDataFrame +head(results) +``` + + +## Data Processing + +**To dplyr users**: SparkR has similar interface as dplyr in data processing. However, some noticeable differences are worth mentioning in the first place. We use `df` to represent a `SparkDataFrame` and `col` to represent the name of column here. + +1. indicate columns. SparkR uses either a character string of the column name or a Column object constructed with `$` to indicate a column. For example, to select `col` in `df`, we can write `select(df, "col")` or `select(df, df$col)`. + +2. describe conditions. In SparkR, the Column object representation can be inserted into the condition directly, or we can use a character string to describe the condition, without referring to the `SparkDataFrame` used. For example, to select rows with value > 1, we can write `filter(df, df$col > 1)` or `filter(df, "col > 1")`. + +Here are more concrete examples. + +dplyr | SparkR +-------- | --------- +`select(mtcars, mpg, hp)` | `select(carsDF, "mpg", "hp")` +`filter(mtcars, mpg > 20, hp > 100)` | `filter(carsDF, carsDF$mpg > 20, carsDF$hp > 100)` + +Other differences will be mentioned in the specific methods. + +We use the `SparkDataFrame` `carsDF` created above. We can get basic information about the `SparkDataFrame`. +```{r} +carsDF +``` + +Print out the schema in tree format. +```{r} +printSchema(carsDF) +``` + +### SparkDataFrame Operations + +#### Selecting rows, columns + +SparkDataFrames support a number of functions to do structured data processing. Here we include some basic examples and a complete list can be found in the [API](https://spark.apache.org/docs/latest/api/R/index.html) docs: + +You can also pass in column name as strings. +```{r} +head(select(carsDF, "mpg")) +``` + +Filter the SparkDataFrame to only retain rows with mpg less than 20 miles/gallon. +```{r} +head(filter(carsDF, carsDF$mpg < 20)) +``` + +#### Grouping, Aggregation + +A common flow of grouping and aggregation is + +1. Use `groupBy` or `group_by` with respect to some grouping variables to create a `GroupedData` object + +2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. + +A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. + +For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. + +```{r} +numCyl <- summarize(groupBy(carsDF, carsDF$cyl), count = n(carsDF$cyl)) +head(numCyl) +``` + +#### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +```{r} +carsDF_km <- carsDF +carsDF_km$kmpg <- carsDF_km$mpg * 1.61 +head(select(carsDF_km, "model", "mpg", "kmpg")) +``` + + +### Window Functions +A window function is a variation of aggregation function. In simple words, + +* aggregation function: `n` to `1` mapping - returns a single value for a group of entries. Examples include `sum`, `count`, `max`. + +* window function: `n` to `n` mapping - returns one value for each entry in the group, but the value may depend on all the entries of the *group*. Examples include `rank`, `lead`, `lag`. + +Formally, the *group* mentioned above is called the *frame*. Every input row can have a unique frame associated with it and the output of the window function on that row is based on the rows confined in that frame. + +Window functions are often used in conjunction with the following functions: `windowPartitionBy`, `windowOrderBy`, `partitionBy`, `orderBy`, `over`. To illustrate this we next look at an example. + +We still use the `mtcars` dataset. The corresponding `SparkDataFrame` is `carsDF`. Suppose for each number of cylinders, we want to calculate the rank of each car in `mpg` within the group. +```{r} +carsSubDF <- select(carsDF, "model", "mpg", "cyl") +ws <- orderBy(windowPartitionBy("cyl"), "mpg") +carsRank <- withColumn(carsSubDF, "rank", over(rank(), ws)) +head(carsRank, n = 20L) +``` + +We explain in detail the above steps. + +* `windowPartitionBy` creates a window specification object `WindowSpec` that defines the partition. It controls which rows will be in the same partition as the given row. In this case, rows with the same value in `cyl` will be put in the same partition. `orderBy` further defines the ordering - the position a given row is in the partition. The resulting `WindowSpec` is returned as `ws`. + +More window specification methods include `rangeBetween`, which can define boundaries of the frame by value, and `rowsBetween`, which can define the boundaries by row indices. + +* `withColumn` appends a Column called `rank` to the `SparkDataFrame`. `over` returns a windowing column. The first argument is usually a Column returned by window function(s) such as `rank()`, `lead(carsDF$wt)`. That calculates the corresponding values according to the partitioned-and-ordered table. + +### User-Defined Function + +In SparkR, we support several kinds of user-defined functions (UDFs). + +#### Apply by Partition + +`dapply` can apply a function to each partition of a `SparkDataFrame`. The function to be applied to each partition of the `SparkDataFrame` should have only one parameter, a `data.frame` corresponding to a partition, and the output should be a `data.frame` as well. Schema specifies the row format of the resulting a `SparkDataFrame`. It must match to data types of returned value. See [here](#DataTypes) for mapping between R and Spark. + +We convert `mpg` to `kmpg` (kilometers per gallon). `carsSubDF` is a `SparkDataFrame` with a subset of `carsDF` columns. + +```{r} +carsSubDF <- select(carsDF, "model", "mpg") +schema <- structType(structField("model", "string"), structField("mpg", "double"), + structField("kmpg", "double")) +out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) +head(collect(out)) +``` + +Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. + +```{r} +out <- dapplyCollect( + carsSubDF, + function(x) { + x <- cbind(x, "kmpg" = x$mpg * 1.61) + }) +head(out, 3) +``` + +#### Apply by Group +`gapply` can apply a function to each group of a `SparkDataFrame`. The function is to be applied to each group of the `SparkDataFrame` and should have only two parameters: grouping key and R `data.frame` corresponding to that key. The groups are chosen from `SparkDataFrames` column(s). The output of function should be a `data.frame`. Schema specifies the row format of the resulting `SparkDataFrame`. It must represent R function’s output schema on the basis of Spark data types. The column names of the returned `data.frame` are set by user. See [here](#DataTypes) for mapping between R and Spark. + +```{r} +schema <- structType(structField("cyl", "double"), structField("max_mpg", "double")) +result <- gapply( + carsDF, + "cyl", + function(key, x) { + y <- data.frame(key, max(x$mpg)) + }, + schema) +head(arrange(result, "max_mpg", decreasing = TRUE)) +``` + +Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. + +```{r} +result <- gapplyCollect( + carsDF, + "cyl", + function(key, x) { + y <- data.frame(key, max(x$mpg)) + colnames(y) <- c("cyl", "max_mpg") + y + }) +head(result[order(result$max_mpg, decreasing = TRUE), ]) +``` + +#### Distribute Local Functions + +Similar to `lapply` in native R, `spark.lapply` runs a function over a list of elements and distributes the computations with Spark. `spark.lapply` works in a manner that is similar to `doParallel` or `lapply` to elements of a list. The results of all the computations should fit in a single machine. If that is not the case you can do something like `df <- createDataFrame(list)` and then use `dapply`. + +We use `svm` in package `e1071` as an example. We use all default settings except for varying costs of constraints violation. `spark.lapply` can train those different models in parallel. + +```{r} +costs <- exp(seq(from = log(1), to = log(1000), length.out = 5)) +train <- function(cost) { + stopifnot(requireNamespace("e1071", quietly = TRUE)) + model <- e1071::svm(Species ~ ., data = iris, cost = cost) + summary(model) +} +``` + +Return a list of model's summaries. +```{r} +model.summaries <- spark.lapply(costs, train) +``` + +```{r} +class(model.summaries) +``` + + +To avoid lengthy display, we only present the result of the second fitted model. You are free to inspect other models as well. +```{r} +print(model.summaries[[2]]) +``` + + +### SQL Queries +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. + +```{r} +people <- read.df(paste0(sparkR.conf("spark.home"), + "/examples/src/main/resources/people.json"), "json") +``` + +Register this SparkDataFrame as a temporary view. + +```{r} +createOrReplaceTempView(people, "people") +``` + +SQL statements can be run by using the sql method. +```{r} +teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +``` + + +## Machine Learning + +SparkR supports the following machine learning models and algorithms. + +* Generalized Linear Model (GLM) + +* Naive Bayes Model + +* $k$-means Clustering + +* Accelerated Failure Time (AFT) Survival Model + +More will be added in the future. + +### R Formula + +For most above, SparkR supports **R formula operators**, including `~`, `.`, `:`, `+` and `-` for model fitting. This makes it a similar experience as using R functions. + +### Training and Test Sets + +We can easily split `SparkDataFrame` into random training and test sets by the `randomSplit` function. It returns a list of split `SparkDataFrames` with provided `weights`. We use `carsDF` as an example and want to have about $70%$ training data and $30%$ test data. +```{r} +splitDF_list <- randomSplit(carsDF, c(0.7, 0.3), seed = 0) +carsDF_train <- splitDF_list[[1]] +carsDF_test <- splitDF_list[[2]] +``` + +```{r} +count(carsDF_train) +head(carsDF_train) +``` + +```{r} +count(carsDF_test) +head(carsDF_test) +``` + + +### Models and Algorithms + +#### Generalized Linear Model + +The main function is `spark.glm`. The following families and link functions are supported. The default is gaussian. + +Family | Link Function +------ | --------- +gaussian | identity, log, inverse +binomial | logit, probit, cloglog (complementary log-log) +poisson | log, identity, sqrt +gamma | inverse, identity, log + +There are three ways to specify the `family` argument. + +* Family name as a character string, e.g. `family = "gaussian"`. + +* Family function, e.g. `family = binomial`. + +* Result returned by a family function, e.g. `family = poisson(link = log)` + +For more information regarding the families and their link functions, see the Wikipedia page [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). + +We use the `mtcars` dataset as an illustration. The corresponding `SparkDataFrame` is `carsDF`. After fitting the model, we print out a summary and see the fitted values by making predictions on the original dataset. We can also pass into a new `SparkDataFrame` of same schema to predict on new data. + +```{r} +gaussianGLM <- spark.glm(carsDF, mpg ~ wt + hp) +summary(gaussianGLM) +``` +When doing prediction, a new column called `prediction` will be appended. Let's look at only a subset of columns here. +```{r} +gaussianFitted <- predict(gaussianGLM, carsDF) +head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) +``` + +#### Naive Bayes Model + +Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. + +```{r} +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) +summary(naiveBayesModel) +naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) +head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +``` + +#### k-Means Clustering + +`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. + +```{r} +kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) +summary(kmeansModel) +kmeansPredictions <- predict(kmeansModel, carsDF) +head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +``` + +#### AFT Survival Model +Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. + +Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. +```{r} +library(survival) +ovarianDF <- createDataFrame(ovarian) +aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) +summary(aftModel) +aftPredictions <- predict(aftModel, ovarianDF) +head(aftPredictions) +``` + +### Model Persistence +The following example shows how to save/load an ML model by SparkR. +```{r} +irisDF <- suppressWarnings(createDataFrame(iris)) +gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") + +# Save and then load a fitted MLlib model +modelPath <- tempfile(pattern = "ml", fileext = ".tmp") +write.ml(gaussianGLM, modelPath) +gaussianGLM2 <- read.ml(modelPath) + +# Check model summary +summary(gaussianGLM2) + +# Check model prediction +gaussianPredictions <- predict(gaussianGLM2, irisDF) +head(gaussianPredictions) + +unlink(modelPath) +``` + + +## Advanced Topics + +### SparkR Object Classes + +There are three main object classes in SparkR you may be working with. + +* `SparkDataFrame`: the central component of SparkR. It is an S4 class representing distributed collection of data organized into named columns, which is conceptually equivalent to a table in a relational database or a data frame in R. It has two slots `sdf` and `env`. + + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + + `env` saves the meta-information of the object such as `isCached`. + +It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + +* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. + +It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. + +This is often an intermediate object with group information and followed up by aggregation operations. + +### Architecture + +A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. + +Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. + +The main method calls of actual computation happen in the Spark JVM of the driver. We have a socket-based SparkR API that allows us to invoke functions on the JVM from R. We use a SparkR JVM backend that listens on a Netty-based socket server. + +Two kinds of RPCs are supported in the SparkR JVM backend: method invocation and creating new objects. Method invocation can be done in two ways. + +* `sparkR.invokeJMethod` takes a reference to an existing Java object and a list of arguments to be passed on to the method. + +* `sparkR.invokeJStatic` takes a class name for static method and a list of arguments to be passed on to the method. + +The arguments are serialized using our custom wire format which is then deserialized on the JVM side. We then use Java reflection to invoke the appropriate method. + +To create objects, `sparkR.newJObject` is used and then similarly the appropriate constructor is invoked with provided arguments. + +Finally, we use a new R class `jobj` that refers to a Java object existing in the backend. These references are tracked on the Java side and are automatically garbage collected when they go out of scope on the R side. + +## Appendix + +### R and Spark Data Types {#DataTypes} + +R | Spark +----------- | ------------- +byte | byte +integer | integer +float | float +double | double +numeric | double +character | string +string | string +binary | binary +raw | binary +logical | boolean +POSIXct | timestamp +POSIXlt | timestamp +Date | date +array | array +list | array +env | map + +## References + +* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) + +* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html) + +* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html) + +* [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016. + +```{r, echo=FALSE} +sparkR.session.stop() +``` diff --git a/R/run-tests.sh b/R/run-tests.sh index 9dcf0ace7d97..1a1e8ab9ffe1 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -26,6 +26,17 @@ rm -f $LOGFILE SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) +# Also run the documentation tests for CRAN +CRAN_CHECK_LOG_FILE=$FWDIR/cran-check.out +rm -f $CRAN_CHECK_LOG_FILE + +NO_TESTS=1 NO_MANUAL=1 $FWDIR/check-cran.sh 2>&1 | tee -a $CRAN_CHECK_LOG_FILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)" +NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)" +NUM_CRAN_NOTES="$(grep -c NOTE$ $CRAN_CHECK_LOG_FILE)" + if [[ $FAILED != 0 ]]; then cat $LOGFILE echo -en "\033[31m" # Red @@ -33,7 +44,17 @@ if [[ $FAILED != 0 ]]; then echo -en "\033[0m" # No color exit -1 else - echo -en "\033[32m" # Green - echo "Tests passed." - echo -en "\033[0m" # No color + # We have 2 existing NOTEs for new maintainer, attach() + # We have one more NOTE in Jenkins due to "No repository set" + if [[ $NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3 ]]; then + cat $CRAN_CHECK_LOG_FILE + echo -en "\033[31m" # Red + echo "Had CRAN check errors; see logs." + echo -en "\033[0m" # No color + exit -1 + else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color + fi fi diff --git a/README.md b/README.md index c77c429e577c..96139e65918e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,19 @@ +## SnappyData's extensions to Spark + +- SnappyData collocates Spark executors with its in-memory data store in the same JVM. To achieve this, support for external cluster manager in Spark 2.0 is used to add a SnappyData cluster manager. +- SnappyData's MemoryManager was needed to generate and handle memory events. A property spark.memory.manager is now used to specify a memory manager other than Spark's own. +- To display the consumption of memory in an external embedded store, Spark's storage UI was updated. +- Support for getting length of type (for VARCHAR) was added in the JDBCDialect class. +- For SnappyData, dynamic continous queries on streams would be enabled in future. For that, support for registering DStreams after streaming context has started is added. +- For partitioning, sequence of expressions can be provided. SnappyData adds OrderlessHashPartitioning that does not take into account order of expressions while partitioning. +- Hive client thread-local configuration changed to be instance specific. +- Hive client added support for dropTable and listing tables for all databases. +- RDD partitions with executor specific preferred locations will be forced to be routed to one of those executors if alive. +- An "unsecure" version of random UUID added in DiskBlockManager for temporary file names. +- Added a fix for SPARK-13116. +- Increased visibility of some classes/methods. + + # Apache Spark Spark is a fast and general cluster computing system for Big Data. It provides diff --git a/assembly/build.gradle b/assembly/build.gradle new file mode 100644 index 000000000000..63db32e3e41f --- /dev/null +++ b/assembly/build.gradle @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Assembly' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-catalyst_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-hive_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-hive-thriftserver_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-repl_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-streaming_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-streaming-kafka-0.8_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-streaming-kafka-0.10_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-yarn_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-mllib_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-graphx_' + scalaBinaryVersion) + if (rootProject.hasProperty('ganglia')) { + compile project(subprojectBase + 'snappy-spark-ganglia-lgpl_' + scalaBinaryVersion) + } +} + +def cleanProduct() { + delete "${sparkProjectRootDir}/python/lib/pyspark.zip" + delete snappyProductDir +} +clean.doLast { + cleanProduct() +} + +task product(type: Zip) { + def examplesProject = project(subprojectBase + 'snappy-spark-examples_' + scalaBinaryVersion) + String yarnShuffleProject = subprojectBase + 'snappy-spark-network-yarn_' + scalaBinaryVersion + dependsOn jar, examplesProject.jar, "${yarnShuffleProject}:shadowJar" + // create python zip + destinationDir = file("${snappyProductDir}/python/lib") + archiveName = 'pyspark.zip' + from("${sparkProjectRootDir}/python") { + include 'pyspark/**/*' + } + + doFirst { + cleanProduct() + } + doLast { + // copy all runtime dependencies (skip for top-level snappydata builds) + if (rootProject.name == 'snappy-spark') { + copy { + from(configurations.runtime) { + // exclude antlr4 explicitly (runtime is still included) + // that gets pulled by antlr gradle plugin + exclude '**antlr4-4*.jar' + // exclude scalatest included by spark-tags + exclude '**scalatest*.jar' + } + into "${snappyProductDir}/jars" + } + } + // copy scripts, data and other files that are part of distribution + copy { + from(sparkProjectRootDir) { + include 'bin/**' + include 'sbin/**' + include 'conf/**' + include 'data/**' + include 'licenses/**' + include 'python/**' + include 'examples/src/**' + } + into snappyProductDir + } + def sparkR = 'sparkProjectRootDir/R/lib/SparkR' + if (file(sparkR).exists()) { + copy { + from sparkR + into "${snappyProductDir}/R/lib" + } + } + + // copy yarn shuffle shadow jar + copy { + from "${project(yarnShuffleProject).buildDir}/jars" + into "${snappyProductDir}/yarn" + } + // copy examples jars + copy { + from "${examplesProject.buildDir}/jars" + into "${snappyProductDir}/examples/jars" + } + // create RELEASE file, copy README etc for top-level snappy-spark project + if (rootProject.name == 'snappy-spark') { + copy { + from(sparkProjectRootDir) { + include 'LICENSE' + include 'NOTICE' + include 'README.md' + } + into snappyProductDir + } + def releaseFile = file("${snappyProductDir}/RELEASE") + String buildFlags = '' + if (rootProject.hasProperty('docker')) { + buildFlags += ' -Pdocker' + } + if (rootProject.hasProperty('ganglia')) { + buildFlags += ' -Pganglia' + } + String gitRevision = "${gitCmd} rev-parse --short HEAD".execute().text.trim() + if (gitRevision.length() > 0) { + gitRevision = " (git revision ${gitRevision})" + } + + releaseFile.append("Spark ${version}${gitRevision} built for Hadoop ${hadoopVersion}\n") + releaseFile.append("Build flags:${buildFlags}\n") + } + } +} diff --git a/assembly/pom.xml b/assembly/pom.xml index 507ddc778383..6db3a599ff5e 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/bin/pyspark b/bin/pyspark index ac8aa04dba8a..037645dbd64d 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -65,7 +65,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.1-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.3-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 3e2ff100fb8a..1217a4f2f97a 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.3-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/bin/spark-class b/bin/spark-class index 658e076bc046..377c8d1add3f 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -80,6 +80,15 @@ done < <(build_command "$@") COUNT=${#CMD[@]} LAST=$((COUNT - 1)) LAUNCHER_EXIT_CODE=${CMD[$LAST]} + +# Certain JVM failures result in errors being printed to stdout (instead of stderr), which causes +# the code that parses the output of the launcher to get confused. In those cases, check if the +# exit code is an integer, and if it's not, handle it as a special error case. +if ! [[ $LAUNCHER_EXIT_CODE =~ ^[0-9]+$ ]]; then + echo "${CMD[@]}" | head -n-1 1>&2 + exit 1 +fi + if [ $LAUNCHER_EXIT_CODE != 0 ]; then exit $LAUNCHER_EXIT_CODE fi diff --git a/build.gradle b/build.gradle new file mode 100644 index 000000000000..292b3c0b6cd6 --- /dev/null +++ b/build.gradle @@ -0,0 +1,379 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +apply plugin: 'wrapper' + +// TODO: profiles and allow changing hadoopVersion + +buildscript { + repositories { + maven { url 'https://plugins.gradle.org/m2' } + mavenCentral() + } + dependencies { + classpath 'io.snappydata:gradle-scalatest:0.13-1' + classpath 'org.github.ngbinh.scalastyle:gradle-scalastyle-plugin_2.11:0.8.2' + } +} + +description = 'Spark Project' + +allprojects { + // We want to see all test results. This is equivalent to setting --continue + // on the command line. + gradle.startParameter.continueOnFailure = true + + repositories { + mavenCentral() + maven { url 'http://repository.apache.org/snapshots' } + } + + apply plugin: 'idea' + + group = 'io.snappydata' + version = '2.0.1-2' + + ext { + scalaBinaryVersion = '2.11' + scalaVersion = scalaBinaryVersion + '.8' + hadoopVersion = '2.7.3' + protobufVersion = '2.6.1' + jerseyVersion = '2.22.2' + sunJerseyVersion = '1.19.1' + jettyVersion = '9.2.16.v20160414' + log4jVersion = '1.2.17' + slf4jVersion = '1.7.21' + junitVersion = '4.12' + javaxServletVersion = '3.1.0' + guavaVersion = '14.0.1' + hiveVersion = '1.2.1.spark2' + chillVersion = '0.8.0' + nettyVersion = '3.8.0.Final' + nettyAllVersion = '4.0.29.Final' + derbyVersion = '10.12.1.1' + httpClientVersion = '4.5.2' + httpCoreVersion = '4.4.4' + fasterXmlVersion = '2.6.5' + snappyJavaVersion = '1.1.2.6' + parquetVersion = '1.7.0' + hiveParquetVersion = '1.6.0' + metricsVersion = '3.1.2' + thriftVersion = '0.9.3' + antlrVersion = '4.5.3' + jpamVersion = '1.1' + seleniumVersion = '2.52.0' + curatorVersion = '2.7.1' + commonsCodecVersion = '1.10' + avroVersion = '1.7.7' + jsr305Version = '3.0.1' + jlineVersion = '2.14.2' + scalatestVersion = '2.2.6' + pegdownVersion = '1.6.0' + + shadePackageName = 'org.spark_project' + } + + // default output directory like in sbt/maven + buildDir = 'build-artifacts/scala-' + scalaBinaryVersion + + ext { + if (rootProject.name == 'snappy-spark') { + subprojectBase = ':' + sparkProjectRoot = ':' + sparkProjectRootDir = project(':').projectDir + testResultsBase = "${rootProject.buildDir}/tests" + gitCmd = "git --git-dir=${rootDir}/.git --work-tree=${rootDir}" + } else { + subprojectBase = ':snappy-spark:' + sparkProjectRoot = ':snappy-spark' + sparkProjectRootDir = project(':snappy-spark').projectDir + testResultsBase = "${rootProject.buildDir}/tests/spark" + gitCmd = "git --git-dir=${project(sparkProjectRoot).projectDir}/.git --work-tree=${project(sparkProjectRoot).projectDir}" + } + snappyProductDir = "${rootProject.buildDir}/snappy" + } +} + +def getStackTrace(def t) { + java.io.StringWriter sw = new java.io.StringWriter() + java.io.PrintWriter pw = new java.io.PrintWriter(sw) + org.codehaus.groovy.runtime.StackTraceUtils.sanitize(t).printStackTrace(pw) + return sw.toString() +} + +task cleanSparkScalaTest << { + def workingDir = "${testResultsBase}/scalatest" + delete workingDir + file(workingDir).mkdirs() +} +task cleanSparkJUnit << { + def workingDir = "${testResultsBase}/junit" + delete workingDir + file(workingDir).mkdirs() +} + +subprojects { + apply plugin: 'scala' + apply plugin: 'maven' + apply plugin: 'scalaStyle' + + // apply compiler options + compileJava.options.encoding = 'UTF-8' + compileJava.options.compilerArgs << '-Xlint:all,-serial,-path,-deprecation' + // compileScala.scalaCompileOptions.optimize = true + compileScala.options.encoding = 'UTF-8' + + javadoc.options.charSet = 'UTF-8' + + scalaStyle { + configLocation = "${sparkProjectRootDir}/scalastyle-config.xml" + inputEncoding = 'UTF-8' + outputEncoding = 'UTF-8' + outputFile = "${buildDir}/scalastyle-output.xml" + includeTestSourceDirectory = false + source = 'src/main/scala' + testSource = 'src/test/scala' + failOnViolation = true + failOnWarning = false + } + + configurations { + runtimeJar { + description 'a dependency to include additional jars at runtime' + visible true + } + } + + // when invoking from snappydata, below are already defined at top-level + if (rootProject.name == 'snappy-spark') { + task packageSources(type: Jar, dependsOn: classes) { + classifier = 'sources' + from sourceSets.main.allSource + } + + configurations { + provided { + description 'a dependency that is provided externally at runtime' + visible true + } + testOutput { + extendsFrom testCompile + description 'a dependency that exposes test artifacts' + } + } + + task packageTests(type: Jar, dependsOn: testClasses) { + description 'Assembles a jar archive of test classes.' + from sourceSets.test.output.classesDir + classifier = 'tests' + } + artifacts { + testOutput packageTests + } + + idea { + module { + scopes.PROVIDED.plus += [ configurations.provided ] + } + } + + sourceSets { + main.compileClasspath += configurations.provided + main.runtimeClasspath -= configurations.provided + test.compileClasspath += configurations.provided + test.runtimeClasspath += configurations.provided + } + + javadoc.classpath += configurations.provided + } + task packageScalaDocs(type: Jar, dependsOn: scaladoc) { + classifier = 'javadoc' + from scaladoc + } + if (rootProject.hasProperty('enablePublish')) { + artifacts { + archives packageScalaDocs, packageSources + } + } + + // fix scala+java mix to all use compileScala which use correct dependency order + sourceSets.main.scala.srcDir 'src/main/java' + sourceSets.main.java.srcDirs = [] + + dependencies { + // This is a dummy dependency that is used along with the shading plug-in + // to create effective poms on publishing (see SPARK-3812). + //compile group: 'org.spark-project.spark', name: 'unused', version: '1.0.0' + compile 'org.scala-lang:scala-library:' + scalaVersion + compile 'org.scala-lang:scala-reflect:' + scalaVersion + + compile group: 'log4j', name:'log4j', version: log4jVersion + compile 'org.slf4j:slf4j-api:' + slf4jVersion + compile 'org.slf4j:slf4j-log4j12:' + slf4jVersion + + testCompile "junit:junit:${junitVersion}" + testCompile "org.scalatest:scalatest_${scalaBinaryVersion}:${scalatestVersion}" + testCompile 'org.mockito:mockito-core:1.10.19' + testCompile 'org.scalacheck:scalacheck_' + scalaBinaryVersion + ':1.12.5' + testCompile 'com.novocode:junit-interface:0.11' + + testRuntime "org.pegdown:pegdown:${pegdownVersion}" + } + + if (rootProject.name == 'snappy-spark') { + task scalaTest(type: Test) { + actions = [ new com.github.maiflai.ScalaTestAction() ] + + List suites = [] + extensions.add(com.github.maiflai.ScalaTestAction.SUITES, suites) + extensions.add('suite', { String name -> suites.add(name) } ) + extensions.add('suites', { String... name -> suites.addAll(name) } ) + + def result = new StringBuilder() + extensions.add(com.github.maiflai.ScalaTestAction.TESTRESULT, result) + extensions.add('testResult', { String name -> result.setLength(0); result.append(name) } ) + + def output = new StringBuilder() + extensions.add(com.github.maiflai.ScalaTestAction.TESTOUTPUT, output) + extensions.add('testOutput', { String name -> output.setLength(0); output.append(name) } ) + + def errorOutput = new StringBuilder() + extensions.add(com.github.maiflai.ScalaTestAction.TESTERROR, errorOutput) + extensions.add('testError', { String name -> errorOutput.setLength(0); errorOutput.append(name) } ) + + // running a single scala suite + if (rootProject.hasProperty('singleSuite')) { + suite singleSuite + } + } + } + scalaTest { + // top-level default is single process run since scalatest does not + // spawn separate JVMs + maxParallelForks = 1 + systemProperties 'test.src.tables': '__not_used__' + + workingDir = "${testResultsBase}/scalatest" + + testResult '/dev/tty' + testOutput "${workingDir}/output.txt" + testError "${workingDir}/error.txt" + binResultsDir = file("${workingDir}/binary/${project.name}") + reports.html.destination = file("${workingDir}/html/${project.name}") + reports.junitXml.destination = file(workingDir) + } + test { + jvmArgs '-Xss4096k' + maxParallelForks = (2 * Runtime.getRuntime().availableProcessors()) + systemProperties 'spark.master.rest.enabled': 'false', + 'test.src.tables': 'src' + + workingDir = "${testResultsBase}/junit" + + binResultsDir = file("${workingDir}/binary/${project.name}") + reports.html.destination = file("${workingDir}/html/${project.name}") + reports.junitXml.destination = file(workingDir) + } + // need to do below after graph is ready else it will give an error about + // runtimeClaspath being set after being finalized + gradle.taskGraph.whenReady({ graph -> + tasks.withType(Test).each { test -> + test.configure { + onlyIf { ! Boolean.getBoolean('skip.tests') } + + jvmArgs '-ea', '-XX:+HeapDumpOnOutOfMemoryError','-XX:+UseConcMarkSweepGC', + '-XX:+UseParNewGC', '-XX:+CMSClassUnloadingEnabled', '-XX:MaxPermSize=512m' + minHeapSize '4g' + maxHeapSize '4g' + // disable assertions for hive tests as in Spark's pom.xml because HiveCompatibilitySuite currently fails (SPARK-4814) + if (test.project.name.contains('snappy-spark-hive_')) { + jvmArgs '-da' + maxParallelForks = 1 + } else { + jvmArgs '-ea' + } + environment 'SPARK_DIST_CLASSPATH': "${sourceSets.test.runtimeClasspath.asPath}", + 'SPARK_PREPEND_CLASSES': '1', + 'SPARK_SCALA_VERSION': scalaBinaryVersion, + 'SPARK_TESTING': '1', + 'JAVA_HOME': System.getProperty('java.home') + systemProperties 'log4j.configuration': "file:${projectDir}/src/test/resources/log4j.properties", + 'derby.system.durability': 'test', + 'java.awt.headless': 'true', + 'java.io.tmpdir': "${rootProject.buildDir}/tmp", + 'spark.test.home': snappyProductDir, + 'spark.project.home': "${project(sparkProjectRoot).projectDir}", + 'spark.testing': '1', + 'spark.ui.enabled': 'false', + 'spark.ui.showConsoleProgress': 'false', + 'spark.driver.allowMultipleContexts': 'true', + 'spark.unsafe.exceptionOnMemoryLeak': 'true' + + testLogging.exceptionFormat = 'full' + + if (rootProject.name == 'snappy-spark') { + def eol = System.getProperty('line.separator') + beforeTest { desc -> + def now = new Date().format('yyyy-MM-dd HH:mm:ss.SSS Z') + def progress = new File(workingDir, 'progress.txt') + def output = new File(workingDir, 'output.txt') + progress << "$now Starting test $desc.className $desc.name$eol" + output << "${now} STARTING TEST ${desc.className} ${desc.name}${eol}${eol}" + } + onOutput { desc, event -> + def output = new File(workingDir, 'output.txt') + output << event.message + } + afterTest { desc, result -> + def now = new Date().format('yyyy-MM-dd HH:mm:ss.SSS Z') + def progress = new File(workingDir, 'progress.txt') + def output = new File(workingDir, 'output.txt') + progress << "${now} Completed test ${desc.className} ${desc.name} with result: ${result.resultType}${eol}" + output << "${eol}${now} COMPLETED TEST ${desc.className} ${desc.name} with result: ${result.resultType}${eol}${eol}" + result.exceptions.each { t -> + progress << " EXCEPTION: ${getStackTrace(t)}${eol}" + output << "${getStackTrace(t)}${eol}" + } + } + } + } + } + }) + test.dependsOn subprojectBase + 'cleanSparkJUnit' + scalaTest.dependsOn subprojectBase + 'cleanSparkScalaTest' + check.dependsOn scalaTest + if (rootProject.name == 'snappy-spark') { + check.dependsOn "${subprojectBase}snappy-spark-assembly_${scalaBinaryVersion}:product" + } +} + +task generateSources { + dependsOn subprojectBase + 'snappy-spark-catalyst_' + scalaBinaryVersion + ':generateGrammarSource' + dependsOn subprojectBase + 'snappy-spark-streaming-flume-sink_' + scalaBinaryVersion + ':generateAvroJava' +} + +if (rootProject.name == 'snappy-spark') { + task scalaStyle { + dependsOn subprojects.scalaStyle + } + task check { + dependsOn subprojects.check + } +} else { + scalaStyle.dependsOn subprojects.scalaStyle + check.dependsOn subprojects.check +} diff --git a/common/network-common/build.gradle b/common/network-common/build.gradle new file mode 100644 index 000000000000..a5f68978b439 --- /dev/null +++ b/common/network-common/build.gradle @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Networking' + +dependencies { + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'io.netty', name: 'netty-all', version: nettyAllVersion + compile group: 'com.google.code.findbugs', name: 'jsr305', version: jsr305Version + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + compile group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: fasterXmlVersion + compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: fasterXmlVersion + compile group: 'org.fusesource.leveldbjni', name: 'leveldbjni-all', version: '1.8' +} diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index bc3b0fe73f6e..269b845565f1 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml @@ -42,6 +42,22 @@ netty-all + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-annotations + + org.slf4j diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 64a83171e9e9..17ac91dd4cdd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -43,7 +43,7 @@ import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow @@ -135,9 +135,10 @@ public void fetchChunk( long streamId, final int chunkIndex, final ChunkReceivedCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + if (logger.isDebugEnabled()) { + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); + } final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); handler.addFetchRequest(streamChunkId, callback); @@ -148,11 +149,13 @@ public void fetchChunk( public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, - timeTaken); + if (logger.isTraceEnabled()) { + logger.trace("Sending request {} to {} took {} ms", streamChunkId, + getRemoteAddress(channel), timeTaken); + } } else { String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - serverAddr, future.cause()); + getRemoteAddress(channel), future.cause()); logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); channel.close(); @@ -173,9 +176,10 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Object to call with the stream data. */ public void stream(final String streamId, final StreamCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.debug("Sending stream request for {} to {}", streamId, serverAddr); + if (logger.isDebugEnabled()) { + logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); + } // Need to synchronize here so that the callback is added to the queue and the RPC is // written to the socket atomically, so that callbacks are called in the right order @@ -188,11 +192,13 @@ public void stream(final String streamId, final StreamCallback callback) { public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr, - timeTaken); + if (logger.isTraceEnabled()) { + logger.trace("Sending request for {} to {} took {} ms", streamId, + getRemoteAddress(channel), timeTaken); + } } else { String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - serverAddr, future.cause()); + getRemoteAddress(channel), future.cause()); logger.error(errorMsg, future.cause()); channel.close(); try { @@ -215,9 +221,10 @@ public void operationComplete(ChannelFuture future) throws Exception { * @return The RPC's id. */ public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.trace("Sending RPC to {}", serverAddr); + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); @@ -228,10 +235,13 @@ public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); + if (logger.isTraceEnabled()) { + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } } else { String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - serverAddr, future.cause()); + getRemoteAddress(channel), future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(requestId); channel.close(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index a27aaf2b277f..1c9916baee07 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -195,7 +195,7 @@ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) /** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException { - logger.debug("Creating new connection to " + address); + logger.debug("Creating new connection to {}", address); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 8a69223c88ee..179667296ec7 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -38,7 +38,7 @@ import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.server.MessageHandler; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; import org.apache.spark.network.util.TransportFrameDecoder; /** @@ -122,7 +122,7 @@ public void channelActive() { @Override public void channelInactive() { if (numOutstandingRequests() > 0) { - String remoteAddress = NettyUtils.getRemoteAddress(channel); + String remoteAddress = getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); @@ -132,7 +132,7 @@ public void channelInactive() { @Override public void exceptionCaught(Throwable cause) { if (numOutstandingRequests() > 0) { - String remoteAddress = NettyUtils.getRemoteAddress(channel); + String remoteAddress = getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); failOutstandingRequests(cause); @@ -141,13 +141,12 @@ public void exceptionCaught(Throwable cause) { @Override public void handle(ResponseMessage message) throws Exception { - String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", - resp.streamChunkId, remoteAddress); + resp.streamChunkId, getRemoteAddress(channel)); resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); @@ -159,7 +158,7 @@ public void handle(ResponseMessage message) throws Exception { ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", - resp.streamChunkId, remoteAddress, resp.errorString); + resp.streamChunkId, getRemoteAddress(channel), resp.errorString); } else { outstandingFetches.remove(resp.streamChunkId); listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( @@ -170,7 +169,7 @@ public void handle(ResponseMessage message) throws Exception { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.body().size()); + resp.requestId, getRemoteAddress(channel), resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); try { @@ -184,7 +183,7 @@ public void handle(ResponseMessage message) throws Exception { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", - resp.requestId, remoteAddress, resp.errorString); + resp.requestId, getRemoteAddress(channel), resp.errorString); } else { outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 074780f2b95c..f0453186185e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -39,7 +39,7 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); Message decoded = decode(msgType, in); assert decoded.type() == msgType; - logger.trace("Received message " + msgType + ": " + decoded); + logger.trace("Received message {}: {}", msgType, decoded); out.add(decoded); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index f2223379a9d2..884ea7d1152a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -29,7 +29,7 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** * The single Transport-level Channel handler which is used for delegating requests to the @@ -76,7 +76,7 @@ public TransportClient getClient() { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + logger.warn("Exception in connection from " + getRemoteAddress(ctx.channel()), cause); requestHandler.exceptionCaught(cause); responseHandler.exceptionCaught(cause); @@ -139,7 +139,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); + String address = getRemoteAddress(ctx.channel()); logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + "requests. Assuming connection is dead; please adjust spark.network.timeout if " + "this is wrong.", address, requestTimeoutNs / 1000 / 1000); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index bebe88ec5d50..143c29c6a57c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.net.SocketAddress; import java.nio.ByteBuffer; import com.google.common.base.Throwables; @@ -42,7 +43,7 @@ import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamRequest; import org.apache.spark.network.protocol.StreamResponse; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** * A handler that processes requests from clients and writes chunk data back. Each handler is @@ -114,9 +115,10 @@ public void handle(RequestMessage request) { } private void processFetchRequest(final ChunkFetchRequest req) { - final String client = NettyUtils.getRemoteAddress(channel); - - logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + if (logger.isTraceEnabled()) { + logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel), + req.streamChunkId); + } ManagedBuffer buf; try { @@ -124,8 +126,8 @@ private void processFetchRequest(final ChunkFetchRequest req) { streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { - logger.error(String.format( - "Error opening block %s for request from %s", req.streamChunkId, client), e); + logger.error(String.format("Error opening block %s for request from %s", req.streamChunkId, + getRemoteAddress(channel)), e); respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); return; } @@ -134,13 +136,12 @@ private void processFetchRequest(final ChunkFetchRequest req) { } private void processStreamRequest(final StreamRequest req) { - final String client = NettyUtils.getRemoteAddress(channel); ManagedBuffer buf; try { buf = streamManager.openStream(req.streamId); } catch (Exception e) { logger.error(String.format( - "Error opening stream %s for request from %s", req.streamId, client), e); + "Error opening stream %s for request from %s", req.streamId, getRemoteAddress(channel)), e); respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e))); return; } @@ -189,13 +190,13 @@ private void processOneWayMessage(OneWayMessage req) { * it will be logged and the channel closed. */ private void respond(final Encodable result) { - final String remoteAddress = channel.remoteAddress().toString(); + final SocketAddress remoteAddress = channel.remoteAddress(); channel.writeAndFlush(result).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { - logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + logger.trace("Sent result {} to client {}", result, remoteAddress); } else { logger.error(String.format("Error sending result %s to %s; closing connection", result, remoteAddress), future.cause()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index baae235e0220..a67db4f69f08 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -130,7 +130,7 @@ protected void initChannel(SocketChannel ch) throws Exception { channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); - logger.debug("Shuffle server started on port :" + port); + logger.debug("Shuffle server started on port: {}", port); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java new file mode 100644 index 000000000000..f96d068cf3d5 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.fusesource.leveldbjni.JniDBFactory; +import org.fusesource.leveldbjni.internal.NativeDB; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.Options; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * LevelDB utility class available in the network package. + */ +public class LevelDBProvider { + private static final Logger logger = LoggerFactory.getLogger(LevelDBProvider.class); + + public static DB initLevelDB(File dbFile, StoreVersion version, ObjectMapper mapper) throws + IOException { + DB tmpDb = null; + if (dbFile != null) { + Options options = new Options(); + options.createIfMissing(false); + options.logger(new LevelDBLogger()); + try { + tmpDb = JniDBFactory.factory.open(dbFile, options); + } catch (NativeDB.DBException e) { + if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { + logger.info("Creating state database at " + dbFile); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(dbFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + } else { + // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new + // one, so we can keep processing new apps + logger.error("error opening leveldb file {}. Creating new file, will not be able to " + + "recover state for existing applications", dbFile, e); + if (dbFile.isDirectory()) { + for (File f : dbFile.listFiles()) { + if (!f.delete()) { + logger.warn("error deleting {}", f.getPath()); + } + } + } + if (!dbFile.delete()) { + logger.warn("error deleting {}", dbFile.getPath()); + } + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(dbFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + + } + } + // if there is a version mismatch, we throw an exception, which means the service is unusable + checkVersion(tmpDb, version, mapper); + } + return tmpDb; + } + + private static class LevelDBLogger implements org.iq80.leveldb.Logger { + private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); + + @Override + public void log(String message) { + LOG.info(message); + } + } + + /** + * Simple major.minor versioning scheme. Any incompatible changes should be across major + * versions. Minor version differences are allowed -- meaning we should be able to read + * dbs that are either earlier *or* later on the minor version. + */ + public static void checkVersion(DB db, StoreVersion newversion, ObjectMapper mapper) throws + IOException { + byte[] bytes = db.get(StoreVersion.KEY); + if (bytes == null) { + storeVersion(db, newversion, mapper); + } else { + StoreVersion version = mapper.readValue(bytes, StoreVersion.class); + if (version.major != newversion.major) { + throw new IOException("cannot read state DB with version " + version + ", incompatible " + + "with current version " + newversion); + } + storeVersion(db, newversion, mapper); + } + } + + public static void storeVersion(DB db, StoreVersion version, ObjectMapper mapper) + throws IOException { + db.put(StoreVersion.KEY, mapper.writeValueAsBytes(version)); + } + + public static class StoreVersion { + + static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8); + + public final int major; + public final int minor; + + @JsonCreator + public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StoreVersion that = (StoreVersion) o; + + return major == that.major && minor == that.minor; + } + + @Override + public int hashCode() { + int result = major; + result = 31 * result + minor; + return result; + } + } +} diff --git a/common/network-shuffle/build.gradle b/common/network-shuffle/build.gradle new file mode 100644 index 000000000000..0ffbc3414ad1 --- /dev/null +++ b/common/network-shuffle/build.gradle @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Shuffle Streaming Service' + +dependencies { + compile project(subprojectBase + 'snappy-spark-network-common_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'org.fusesource.leveldbjni', name: 'leveldbjni-all', version: '1.8' + compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: fasterXmlVersion + compile group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: fasterXmlVersion + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + + testCompile project(path: subprojectBase + 'snappy-spark-network-common_' + scalaBinaryVersion, configuration: 'testOutput') +} diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 2fb5835305a2..20cf29efffc7 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml @@ -43,22 +43,6 @@ ${project.version} - - org.fusesource.leveldbjni - leveldbjni-all - 1.8 - - - - com.fasterxml.jackson.core - jackson-databind - - - - com.fasterxml.jackson.core - jackson-annotations - - org.slf4j diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 54e870a9b56a..e34dc1f5b1de 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -30,17 +30,16 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.collect.Maps; -import org.fusesource.leveldbjni.JniDBFactory; -import org.fusesource.leveldbjni.internal.NativeDB; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; -import org.iq80.leveldb.Options; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.LevelDBProvider; +import org.apache.spark.network.util.LevelDBProvider.StoreVersion; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -95,52 +94,10 @@ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorF Executor directoryCleaner) throws IOException { this.conf = conf; this.registeredExecutorFile = registeredExecutorFile; - if (registeredExecutorFile != null) { - Options options = new Options(); - options.createIfMissing(false); - options.logger(new LevelDBLogger()); - DB tmpDb; - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException e) { - if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { - logger.info("Creating state database at " + registeredExecutorFile); - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - } else { - // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new - // one, so we can keep processing new apps - logger.error("error opening leveldb file {}. Creating new file, will not be able to " + - "recover state for existing applications", registeredExecutorFile, e); - if (registeredExecutorFile.isDirectory()) { - for (File f : registeredExecutorFile.listFiles()) { - if (!f.delete()) { - logger.warn("error deleting {}", f.getPath()); - } - } - } - if (!registeredExecutorFile.delete()) { - logger.warn("error deleting {}", registeredExecutorFile.getPath()); - } - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - - } - } - // if there is a version mismatch, we throw an exception, which means the service is unusable - checkVersion(tmpDb); - executors = reloadRegisteredExecutors(tmpDb); - db = tmpDb; + db = LevelDBProvider.initLevelDB(this.registeredExecutorFile, CURRENT_VERSION, mapper); + if (db != null) { + executors = reloadRegisteredExecutors(db); } else { - db = null; executors = Maps.newConcurrentMap(); } this.directoryCleaner = directoryCleaner; @@ -244,7 +201,7 @@ private void deleteExecutorDirs(String[] dirs) { for (String localDir : dirs) { try { JavaUtils.deleteRecursively(new File(localDir)); - logger.debug("Successfully cleaned up directory: " + localDir); + logger.debug("Successfully cleaned up directory: {}", localDir); } catch (Exception e) { logger.error("Failed to delete directory: " + localDir, e); } @@ -368,76 +325,11 @@ static ConcurrentMap reloadRegisteredExecutors(D break; } AppExecId id = parseDbAppExecKey(key); + logger.info("Reloading registered executors: " + id.toString()); ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); registeredExecutors.put(id, shuffleInfo); } } return registeredExecutors; } - - private static class LevelDBLogger implements org.iq80.leveldb.Logger { - private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); - - @Override - public void log(String message) { - LOG.info(message); - } - } - - /** - * Simple major.minor versioning scheme. Any incompatible changes should be across major - * versions. Minor version differences are allowed -- meaning we should be able to read - * dbs that are either earlier *or* later on the minor version. - */ - private static void checkVersion(DB db) throws IOException { - byte[] bytes = db.get(StoreVersion.KEY); - if (bytes == null) { - storeVersion(db); - } else { - StoreVersion version = mapper.readValue(bytes, StoreVersion.class); - if (version.major != CURRENT_VERSION.major) { - throw new IOException("cannot read state DB with version " + version + ", incompatible " + - "with current version " + CURRENT_VERSION); - } - storeVersion(db); - } - } - - private static void storeVersion(DB db) throws IOException { - db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION)); - } - - - public static class StoreVersion { - - static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8); - - public final int major; - public final int minor; - - @JsonCreator public StoreVersion( - @JsonProperty("major") int major, - @JsonProperty("minor") int minor) { - this.major = major; - this.minor = minor; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - StoreVersion that = (StoreVersion) o; - - return major == that.major && minor == that.minor; - } - - @Override - public int hashCode() { - int result = major; - result = 31 * result + minor; - return result; - } - } - } diff --git a/common/network-yarn/build.gradle b/common/network-yarn/build.gradle new file mode 100644 index 000000000000..bbb6d8c7f81a --- /dev/null +++ b/common/network-yarn/build.gradle @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +plugins { + id 'com.github.johnrengelman.shadow' version '1.2.3' +} + +description = 'Spark Project YARN Shuffle Service' + +dependencies { + compile project(subprojectBase + 'snappy-spark-network-shuffle_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'io.netty', name: 'netty-all', version: nettyAllVersion + provided (group: 'org.apache.hadoop', name: 'hadoop-client', version: hadoopVersion) { + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.codehaus.jackson', module: 'jackson-core-asl') + exclude(group: 'org.codehaus.jackson', module: 'jackson-mapper-asl') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'org.mockito', module: 'mockito-all') + exclude(group: 'org.mortbay.jetty', module: 'servlet-api-2.5') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'junit', module: 'junit') + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + } + + /* + runtimeJar project(subprojectBase + 'snappy-spark-network-common_' + scalaBinaryVersion) + runtimeJar project(subprojectBase + 'snappy-spark-network-shuffle_' + scalaBinaryVersion) + runtimeJar group: 'io.netty', name: 'netty-all', version: nettyAllVersion + runtimeJar group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: fasterXmlVersion + runtimeJar group: 'com.fasterxml.jackson.core', name: 'jackson-annotations', version: fasterXmlVersion + */ +} + +shadowJar { + baseName 'snappy-spark' + classifier 'yarn-shuffle' + + mergeServiceFiles { + exclude 'META-INF/*.SF' + exclude 'META-INF/*.DSA' + exclude 'META-INF/*.RSA' + } + + dependencies { + exclude(dependency('org.scala-lang:.*')) + exclude(dependency('org.scala-lang.modules:.*')) + exclude(dependency('org.slf4j:.*')) + exclude(dependency('log4j:.*')) + exclude(dependency('org.scalatest:.*')) + } + //configurations = [ project.configurations.runtimeJar ] + + relocate 'io.netty', "${shadePackageName}.io.netty" + relocate 'com.fasterxml.jackson', "${shadePackageName}.com.fasterxml.jackson" + relocate 'com.google.common', "${shadePackageName}.guava" + + String createdBy = '' + if (rootProject.hasProperty('enablePublish')) { + createdBy = 'SnappyData Build Team' + } else { + createdBy = System.getProperty('user.name') + } + manifest { + attributes( + 'Manifest-Version' : '1.0', + 'Created-By' : createdBy, + 'Title' : project.name, + 'Version' : version, + 'Vendor' : 'SnappyData, Inc.' + ) + } + + doLast { + copy { + from outputs + into "${buildDir}/jars" + } + } +} diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 07d9f1c58f7a..25cc32889ef3 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index b6feb55e2192..7bdf73539671 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -18,15 +18,28 @@ package org.apache.spark.network.yarn; import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.nio.ByteBuffer; import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.server.api.*; +import org.apache.spark.network.util.LevelDBProvider; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.DBIterator; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,6 +81,22 @@ public class YarnShuffleService extends AuxiliaryService { private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + private static final String RECOVERY_FILE_NAME = "registeredExecutors.ldb"; + private static final String SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb"; + + // just for testing when you want to find an open port + @VisibleForTesting + static int boundPort = -1; + private static final ObjectMapper mapper = new ObjectMapper(); + private static final String APP_CREDS_KEY_PREFIX = "AppCreds"; + private static final LevelDBProvider.StoreVersion CURRENT_VERSION = new LevelDBProvider + .StoreVersion(1, 0); + + // just for integration tests that want to look at this file -- in general not sensible as + // a static + @VisibleForTesting + static YarnShuffleService instance; + // An entity that manages the shuffle secret per application // This is used only if authentication is enabled private ShuffleSecretManager secretManager; @@ -75,6 +104,8 @@ public class YarnShuffleService extends AuxiliaryService { // The actual server that serves shuffle files private TransportServer shuffleServer = null; + private Configuration _conf = null; + // Handles registering executors and opening shuffle blocks @VisibleForTesting ExternalShuffleBlockHandler blockHandler; @@ -83,14 +114,11 @@ public class YarnShuffleService extends AuxiliaryService { @VisibleForTesting File registeredExecutorFile; - // just for testing when you want to find an open port + // Where to store & reload application secrets for recovering state after an NM restart @VisibleForTesting - static int boundPort = -1; + File secretsFile; - // just for integration tests that want to look at this file -- in general not sensible as - // a static - @VisibleForTesting - static YarnShuffleService instance; + private DB db; public YarnShuffleService() { super("spark_shuffle"); @@ -112,42 +140,86 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { + _conf = conf; - // In case this NM was killed while there were running spark applications, we need to restore - // lost state for the existing executors. We look for an existing file in the NM's local dirs. - // If we don't find one, then we choose a file to use to save the state next time. Even if - // an application was stopped while the NM was down, we expect yarn to call stopApplication() - // when it comes back - registeredExecutorFile = - findRegisteredExecutorFile(conf.getTrimmedStrings("yarn.nodemanager.local-dirs")); - - TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); - // If authentication is enabled, set up the shuffle server to use a - // special RPC handler that filters out unauthenticated fetch requests - boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); try { - blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. Even if + // an application was stopped while the NM was down, we expect yarn to call stopApplication() + // when it comes back + registeredExecutorFile = findRecoveryDb(RECOVERY_FILE_NAME); + + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); + blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + List bootstraps = Lists.newArrayList(); + if (authEnabled) { + createSecretManager(); + bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportContext transportContext = new TransportContext(transportConf, blockHandler); + shuffleServer = transportContext.createServer(port, bootstraps); + // the port should normally be fixed, but for tests its useful to find an open port + port = shuffleServer.getPort(); + boundPort = port; + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. " + + "Authentication is {}. Registered executor file is {}", port, authEnabledString, + registeredExecutorFile); } catch (Exception e) { logger.error("Failed to initialize external shuffle service", e); } + } - List bootstraps = Lists.newArrayList(); - if (authEnabled) { - secretManager = new ShuffleSecretManager(); - bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); + private void createSecretManager() throws IOException { + secretManager = new ShuffleSecretManager(); + secretsFile = findRecoveryDb(SECRETS_RECOVERY_FILE_NAME); + + // Make sure this is protected in case its not in the NM recovery dir + FileSystem fs = FileSystem.getLocal(_conf); + fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission((short)0700)); + + db = LevelDBProvider.initLevelDB(secretsFile, CURRENT_VERSION, mapper); + logger.info("Recovery location is: " + secretsFile.getPath()); + if (db != null) { + logger.info("Going to reload spark shuffle data"); + DBIterator itr = db.iterator(); + itr.seek(APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), StandardCharsets.UTF_8); + if (!key.startsWith(APP_CREDS_KEY_PREFIX)) { + break; + } + String id = parseDbAppKey(key); + ByteBuffer secret = mapper.readValue(e.getValue(), ByteBuffer.class); + logger.info("Reloading tokens for app: " + id); + secretManager.registerApp(id, secret); + } } + } - int port = conf.getInt( - SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportContext transportContext = new TransportContext(transportConf, blockHandler); - shuffleServer = transportContext.createServer(port, bootstraps); - // the port should normally be fixed, but for tests its useful to find an open port - port = shuffleServer.getPort(); - boundPort = port; - String authEnabledString = authEnabled ? "enabled" : "not enabled"; - logger.info("Started YARN shuffle service for Spark on port {}. " + - "Authentication is {}. Registered executor file is {}", port, authEnabledString, - registeredExecutorFile); + private static String parseDbAppKey(String s) throws IOException { + if (!s.startsWith(APP_CREDS_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX); + } + String json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1); + AppId parsed = mapper.readValue(json, AppId.class); + return parsed.appId; + } + + private static byte[] dbAppKey(AppId appExecId) throws IOException { + // we stick a common prefix on all the keys so we can find them in the DB + String appExecJson = mapper.writeValueAsString(appExecId); + String key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson); + return key.getBytes(StandardCharsets.UTF_8); } @Override @@ -157,6 +229,12 @@ public void initializeApplication(ApplicationInitializationContext context) { ByteBuffer shuffleSecret = context.getApplicationDataForService(); logger.info("Initializing application {}", appId); if (isAuthenticationEnabled()) { + AppId fullId = new AppId(appId); + if (db != null) { + byte[] key = dbAppKey(fullId); + byte[] value = mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8); + db.put(key, value); + } secretManager.registerApp(appId, shuffleSecret); } } catch (Exception e) { @@ -170,6 +248,14 @@ public void stopApplication(ApplicationTerminationContext context) { try { logger.info("Stopping application {}", appId); if (isAuthenticationEnabled()) { + AppId fullId = new AppId(appId); + if (db != null) { + try { + db.delete(dbAppKey(fullId)); + } catch (IOException e) { + logger.error("Error deleting {} from executor state db", appId, e); + } + } secretManager.unregisterApp(appId); } blockHandler.applicationRemoved(appId, false /* clean up local dirs */); @@ -190,14 +276,15 @@ public void stopContainer(ContainerTerminationContext context) { logger.info("Stopping container {}", containerId); } - private File findRegisteredExecutorFile(String[] localDirs) { + private File findRecoveryDb(String fileName) { + String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir: localDirs) { - File f = new File(new Path(dir).toUri().getPath(), "registeredExecutors.ldb"); + File f = new File(new Path(dir).toUri().getPath(), fileName); if (f.exists()) { return f; } } - return new File(new Path(localDirs[0]).toUri().getPath(), "registeredExecutors.ldb"); + return new File(new Path(localDirs[0]).toUri().getPath(), fileName); } /** @@ -212,6 +299,9 @@ protected void serviceStop() { if (blockHandler != null) { blockHandler.close(); } + if (db != null) { + db.close(); + } } catch (Exception e) { logger.error("Exception when stopping service", e); } @@ -222,4 +312,38 @@ protected void serviceStop() { public ByteBuffer getMetaData() { return ByteBuffer.allocate(0); } + + /** + * Simply encodes an application ID. + */ + public static class AppId { + public final String appId; + + @JsonCreator + public AppId(@JsonProperty("appId") String appId) { + this.appId = appId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppId appExecId = (AppId) o; + return Objects.equal(appId, appExecId.appId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .toString(); + } + } + } diff --git a/common/sketch/build.gradle b/common/sketch/build.gradle new file mode 100644 index 000000000000..a5e5efff08b5 --- /dev/null +++ b/common/sketch/build.gradle @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Sketch' + +dependencies { + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) +} + +tasks.withType(JavaCompile) { + options.compilerArgs << '-XDignore.symbol.file' + options.fork = true + options.forkOptions.executable = 'javac' +} diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 5e02efdc45e6..37a5d09a3ff0 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/common/tags/build.gradle b/common/tags/build.gradle new file mode 100644 index 000000000000..e272cfbaa638 --- /dev/null +++ b/common/tags/build.gradle @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Tags' + +dependencies { + compile "org.scalatest:scalatest_${scalaBinaryVersion}:${scalatestVersion}" +} diff --git a/common/tags/pom.xml b/common/tags/pom.xml index e7fc6a2a0241..ab287f3368a4 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/common/unsafe/build.gradle b/common/unsafe/build.gradle new file mode 100644 index 000000000000..69d29942f5f1 --- /dev/null +++ b/common/unsafe/build.gradle @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Unsafe' + +dependencies { + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion + compile group: 'com.google.code.findbugs', name: 'jsr305', version: jsr305Version + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + + testCompile group: 'org.apache.commons', name: 'commons-lang3', version: '3.3.2' +} + +// reset the srcDirs to allow javac compilation with specific args below +sourceSets.main.scala.srcDirs = [ 'src/main/scala' ] +sourceSets.main.java.srcDirs = [ 'src/main/java' ] + +tasks.withType(JavaCompile) { + options.compilerArgs << '-XDignore.symbol.file' + options.fork = true + options.forkOptions.executable = 'javac' +} diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 24f0e75f2f04..45831ce98dbc 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 54a54569240c..ce8351694308 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -465,9 +465,9 @@ public UTF8String trim() { int s = 0; int e = this.numBytes - 1; // skip all of the space (0x20) in the left side - while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + while (s < this.numBytes && getByte(s) == 0x20) s++; // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + while (e >= 0 && getByte(e) == 0x20) e--; if (s > e) { // empty string return UTF8String.fromBytes(new byte[0]); @@ -479,7 +479,7 @@ public UTF8String trim() { public UTF8String trimLeft() { int s = 0; // skip all of the space (0x20) in the left side - while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + while (s < this.numBytes && getByte(s) == 0x20) s++; if (s == this.numBytes) { // empty string return UTF8String.fromBytes(new byte[0]); @@ -491,7 +491,7 @@ public UTF8String trimLeft() { public UTF8String trimRight() { int e = numBytes - 1; // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + while (e >= 0 && getByte(e) == 0x20) e--; if (e < 0) { // empty string diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index d4160ad029eb..7f03686dcec4 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -232,6 +232,16 @@ public void trims() { assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + + char[] charsLessThan0x20 = new char[10]; + Arrays.fill(charsLessThan0x20, (char)(' ' - 1)); + String stringStartingWithSpace = + new String(charsLessThan0x20) + "hello" + new String(charsLessThan0x20); + assertEquals(fromString(stringStartingWithSpace), fromString(stringStartingWithSpace).trim()); + assertEquals(fromString(stringStartingWithSpace), + fromString(stringStartingWithSpace).trimLeft()); + assertEquals(fromString(stringStartingWithSpace), + fromString(stringStartingWithSpace).trimRight()); } @Test diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 8a6b9e3e4536..62d4176d00f9 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -98,7 +98,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty } } - val whitespaceChar: Gen[Char] = Gen.choose(0x00, 0x20).map(_.toChar) + val whitespaceChar: Gen[Char] = Gen.const(0x20.toChar) val whitespaceString: Gen[String] = Gen.listOf(whitespaceChar).map(_.mkString) val randomString: Gen[String] = Arbitrary.arbString.arbitrary @@ -107,7 +107,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def lTrim(s: String): String = { var st = 0 val array: Array[Char] = s.toCharArray - while ((st < s.length) && (array(st) <= ' ')) { + while ((st < s.length) && (array(st) == ' ')) { st += 1 } if (st > 0) s.substring(st, s.length) else s @@ -115,7 +115,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def rTrim(s: String): String = { var len = s.length val array: Array[Char] = s.toCharArray - while ((len > 0) && (array(len - 1) <= ' ')) { + while ((len > 0) && (array(len - 1) == ' ')) { len -= 1 } if (len < s.length) s.substring(0, len) else s @@ -127,7 +127,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty whitespaceString ) { (start: String, middle: String, end: String) => val s = start + middle + end - assert(toUTF8(s).trim() === toUTF8(s.trim())) + assert(toUTF8(s).trim() === toUTF8(rTrim(lTrim(s)))) assert(toUTF8(s).trimLeft() === toUTF8(lTrim(s))) assert(toUTF8(s).trimRight() === toUTF8(rTrim(s))) } diff --git a/core/build.gradle b/core/build.gradle new file mode 100644 index 000000000000..9395a129dac3 --- /dev/null +++ b/core/build.gradle @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Core' + +dependencies { + compile project(subprojectBase + 'snappy-spark-launcher_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-network-common_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-network-shuffle_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-unsafe_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile(group: 'org.apache.avro', name: 'avro-ipc', version: avroVersion) { + exclude(group: 'io.netty', module: 'netty') + exclude(group: 'org.mortbay.jetty', module: 'jetty') + exclude(group: 'org.mortbay.jetty', module: 'jetty-util') + exclude(group: 'org.mortbay.jetty', module: 'servlet-api') + exclude(group: 'org.apache.velocity', module: 'velocity') + } + compile(group: 'org.apache.avro', name: 'avro-mapred', version: avroVersion, classifier: 'hadoop2') { + exclude(group: 'io.netty', module: 'netty') + exclude(group: 'org.mortbay.jetty', module: 'jetty') + exclude(group: 'org.mortbay.jetty', module: 'jetty-util') + exclude(group: 'org.mortbay.jetty', module: 'servlet-api') + exclude(group: 'org.apache.velocity', module: 'velocity') + exclude(group: 'org.apache.avro', module: 'avro-ipc') + } + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion + compile group: 'com.twitter', name: 'chill-java', version: chillVersion + compile group: 'org.apache.xbean', name: 'xbean-asm5-shaded', version: '4.4' + // explicitly include netty from akka-remote to not let zookeeper override it + compile group: 'io.netty', name: 'netty', version: nettyVersion + // explicitly exclude old netty from zookeeper + compile(group: 'org.apache.zookeeper', name: 'zookeeper', version: '3.4.8') { + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'jline', module: 'jline') + } + compile group: 'com.google.protobuf', name: 'protobuf-java', version: protobufVersion + compile(group: 'org.apache.hadoop', name: 'hadoop-client', version: hadoopVersion) { + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.codehaus.jackson', module: 'jackson-mapper-asl') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.apache.zookeeper', module: 'zookeeper') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'jline', module: 'jline') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'org.mockito', module: 'mockito-all') + exclude(group: 'org.mortbay.jetty', module: 'servlet-api-2.5') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'junit', module: 'junit') + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + exclude(group: 'com.google.protobuf', module: 'protobuf-java') + } + compile(group: 'net.java.dev.jets3t', name: 'jets3t', version: '0.9.3') { + exclude(group: 'commons-logging', module: 'commons-logging') + } + compile(group: 'org.apache.curator', name: 'curator-recipes', version: curatorVersion) { + exclude(group: 'org.apache.zookeeper', module: 'zookeeper') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'jline', module: 'jline') + exclude(group: 'com.google.guava', module: 'guava') + } + + compile 'org.scala-lang:scalap:' + scalaVersion + compile group: 'org.apache.xbean', name: 'xbean-asm5-shaded' , version: '4.4' + compile group: 'org.roaringbitmap', name: 'RoaringBitmap' , version: '0.5.11' + + compile group: 'org.eclipse.jetty', name: 'jetty-server', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-plus', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-util', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-http', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-servlet', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-servlets', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-security', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-continuation', version: jettyVersion + compile group: 'javax.servlet', name: 'javax.servlet-api', version: javaxServletVersion + compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.3.2' + compile group: 'org.apache.commons', name: 'commons-math3', version: '3.4.1' + compile group: 'com.google.code.findbugs', name: 'jsr305', version: jsr305Version + compile group: 'org.slf4j', name: 'jul-to-slf4j', version: slf4jVersion + compile group: 'org.slf4j', name: 'jcl-over-slf4j', version: slf4jVersion + compile group: 'com.ning', name: 'compress-lzf', version: '1.0.3' + compile group: 'org.xerial.snappy', name: 'snappy-java', version: snappyJavaVersion + compile group: 'net.jpountz.lz4', name: 'lz4', version: '1.3.0' + compile group: 'commons-net', name: 'commons-net', version: '2.2' + compile group: 'org.json4s', name: 'json4s-jackson_' + scalaBinaryVersion, version: '3.2.11' + compile group: 'org.glassfish.jersey.core', name: 'jersey-client', version: jerseyVersion + compile group: 'org.glassfish.jersey.core', name: 'jersey-common', version: jerseyVersion + compile group: 'org.glassfish.jersey.core', name: 'jersey-server', version: jerseyVersion + compile group: 'org.glassfish.jersey.containers', name: 'jersey-container-servlet', version: jerseyVersion + compile group: 'org.glassfish.jersey.containers', name: 'jersey-container-servlet-core', version: jerseyVersion + compile(group: 'org.apache.mesos', name: 'mesos', version: '0.21.1', classifier: 'shaded-protobuf') { + exclude(group: 'com.google.protobuf', module: 'protobuf-java') + } + compile group: 'io.netty', name: 'netty-all', version: nettyAllVersion + compile(group: 'com.clearspring.analytics', name: 'stream', version: '2.7.0') { + exclude(group: 'it.unimi.dsi', module: 'fastutil') + } + compile group: 'io.dropwizard.metrics', name: 'metrics-core', version: metricsVersion + compile group: 'io.dropwizard.metrics', name: 'metrics-jvm', version: metricsVersion + compile group: 'io.dropwizard.metrics', name: 'metrics-json', version: metricsVersion + compile group: 'io.dropwizard.metrics', name: 'metrics-graphite', version: metricsVersion + compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: fasterXmlVersion + compile(group: 'com.fasterxml.jackson.module', name: 'jackson-module-scala_' + scalaBinaryVersion, version: fasterXmlVersion) { + exclude(group: 'com.google.guava', module: 'guava') + } + compile group: 'org.apache.ivy', name: 'ivy', version: '2.4.0' + compile group: 'oro', name: 'oro', version: '2.0.8' + compile(group: 'net.razorvine', name: 'pyrolite', version: '4.9') { + exclude(group: 'net.razorvine', module: 'serpent') + } + compile group: 'net.sf.py4j', name: 'py4j', version: '0.10.1' + + testCompile group: 'org.apache.avro', name: 'avro-ipc', version: avroVersion, classifier: 'tests' + testCompile "org.apache.derby:derby:${derbyVersion}" + testCompile(group: 'org.seleniumhq.selenium', name: 'selenium-java', version: seleniumVersion) { + exclude(group: 'com.google.guava', module: 'guava') + } + testCompile(group: 'org.seleniumhq.selenium', name: 'selenium-htmlunit-driver', version: seleniumVersion) { + exclude(group: 'com.google.guava', module: 'guava') + } + testCompile group: 'xml-apis', name: 'xml-apis', version: '1.4.01' + testCompile group: 'org.hamcrest', name: 'hamcrest-core', version: '1.3' + testCompile group: 'org.hamcrest', name: 'hamcrest-library', version: '1.3' + testCompile(group: 'org.apache.curator', name: 'curator-test', version: curatorVersion) { + exclude(group: 'org.apache.zookeeper', module: 'zookeeper') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'jline', module: 'jline') + exclude(group: 'com.google.guava', module: 'guava') + } +} + +// TODO: sparkr profile, copy-dependencies target? + +// fix scala+java test ordering +sourceSets.test.scala.srcDir 'src/test/java' +sourceSets.test.java.srcDirs = [] + +// generate properties using spark-build-info and add to project resources +String extraResourceDir = "${buildDir}/extra-resources" + +task generateBuildInfo { + outputs.file "${extraResourceDir}/spark-version-info.properties" + inputs.dir compileScala.destinationDir + + doLast { + file(extraResourceDir).mkdirs() + exec { + executable 'bash' + workingDir = buildDir + args "${projectDir}/../build/spark-build-info", extraResourceDir, version + } + } +} +sourceSets { + main { + // register generated resources on the main SourceSet + output.dir(extraResourceDir, builtBy: 'generateBuildInfo') + } +} diff --git a/core/pom.xml b/core/pom.xml index bb27ec916c65..2d19e5b81cf5 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml @@ -327,7 +327,7 @@ net.sf.py4j py4j - 0.10.1 + 0.10.3 org.apache.spark diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 0e9defe5b4a5..601dd6edfcf7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -156,8 +156,14 @@ public void write(Iterator> records) throws IOException { File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); File tmp = Utils.tempFileWith(output); - partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + try { + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + } + } mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 44e6aa73d975..c08a5d424ea6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -207,15 +207,21 @@ void closeAndWriteOutput() throws IOException { final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); final File tmp = Utils.tempFileWith(output); try { - partitionLengths = mergeSpills(spills, tmp); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && ! spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); + try { + partitionLengths = mergeSpills(spills, tmp); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } } } + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 50f5b068b276..999ded45f2e2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -521,7 +521,8 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.getBaseObject() != upstream.getBaseObject()) { + if (!loaded || page.pageNumber != + ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); } else { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c7b070f519f8..b51737158098 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -225,6 +225,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea private long baseOffset; private long keyPrefix; private int recordLength; + private long currentPageNumber; private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; @@ -239,6 +240,7 @@ public SortedIterator clone() { iter.baseOffset = baseOffset; iter.keyPrefix = keyPrefix; iter.recordLength = recordLength; + iter.currentPageNumber = currentPageNumber; return iter; } @@ -256,6 +258,7 @@ public boolean hasNext() { public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); + currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); @@ -269,6 +272,10 @@ public void loadNext() { @Override public long getBaseOffset() { return baseOffset; } + public long getCurrentPageNumber() { + return currentPageNumber; + } + @Override public int getRecordLength() { return recordLength; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index d2161662d567..177120aaa6c1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -15,6 +15,12 @@ * limitations under the License. */ +var appLimit = -1; + +function setAppLimit(val) { + appLimit = val; +} + // this function works exactly the same as UIUtils.formatDuration function formatDuration(milliseconds) { if (milliseconds < 100) { @@ -111,7 +117,7 @@ $(document).ready(function() { requestedIncomplete = getParameterByName("showIncomplete", searchString); requestedIncomplete = (requestedIncomplete == "true" ? true : false); - $.getJSON("api/v1/applications", function(response,status,jqXHR) { + $.getJSON("api/v1/applications?limit=" + appLimit, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css index 0f400461c529..3bf3e8bfa1f3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -33,12 +33,15 @@ div#application-timeline, div#job-timeline { height: 55px; } -#task-assignment-timeline div.item.range { - padding: 0px; +#task-assignment-timeline div.vis-item.vis-range { height: 26px; border-width: 0; } +#task-assignment-timeline .vis-item-content { + padding: 0px; +} + .task-assignment-timeline-content { width: 100%; } @@ -83,24 +86,24 @@ rect.getting-result-time-proportion { stroke: #75B0A6; } -.vis.timeline { +.vis-timeline { line-height: 14px; } -.vis.timeline div.content { +.vis-timeline div.vis-item-content { width: 100%; } -.vis.timeline .item.stage { +.vis-timeline .vis-item.stage { cursor: pointer; } -.vis.timeline .item.stage.succeeded { +.vis-timeline .vis-item.stage.succeeded { background-color: #A0DFFF; border-color: #3EC0FF; } -.vis.timeline .item.stage.succeeded.selected { +.vis-timeline .vis-item.stage.succeeded.vis-selected { background-color: #A0DFFF; border-color: #3EC0FF; z-index: auto; @@ -111,12 +114,12 @@ rect.getting-result-time-proportion { stroke: #3EC0FF; } -.vis.timeline .item.stage.failed { +.vis-timeline .vis-item.stage.failed { background-color: #FFA1B0; border-color: #FF4D6D; } -.vis.timeline .item.stage.failed.selected { +.vis-timeline .vis-item.stage.failed.vis-selected { background-color: #FFA1B0; border-color: #FF4D6D; z-index: auto; @@ -127,12 +130,12 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis.timeline .item.stage.running { +.vis-timeline .vis-item.stage.running { background-color: #A2FCC0; border-color: #36F572; } -.vis.timeline .item.stage.running.selected { +.vis-timeline .vis-item.stage.running.vis-selected { background-color: #A2FCC0; border-color: #36F572; z-index: auto; @@ -143,20 +146,20 @@ rect.getting-result-time-proportion { stroke: #36F572; } -.vis.timeline .foreground { +.vis-timeline .vis-foreground { cursor: move; } -.vis.timeline .item.job { +.vis-timeline .vis-item.job { cursor: pointer; } -.vis.timeline .item.job.succeeded { +.vis-timeline .vis-item.job.succeeded { background-color: #A0DFFF; border-color: #3EC0FF; } -.vis.timeline .item.job.succeeded.selected { +.vis-timeline .vis-item.job.succeeded.vis-selected { background-color: #A0DFFF; border-color: #3EC0FF; z-index: auto; @@ -167,12 +170,12 @@ rect.getting-result-time-proportion { stroke: #3EC0FF; } -.vis.timeline .item.job.failed { +.vis-timeline .vis-item.job.failed { background-color: #FFA1B0; border-color: #FF4D6D; } -.vis.timeline .item.job.failed.selected { +.vis-timeline .vis-item.job.failed.vis-selected { background-color: #FFA1B0; border-color: #FF4D6D; z-index: auto; @@ -183,12 +186,12 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis.timeline .item.job.running { +.vis-timeline .vis-item.job.running { background-color: #A2FCC0; border-color: #36F572; } -.vis.timeline .item.job.running.selected { +.vis-timeline .vis-item.job.running.vis-selected { background-color: #A2FCC0; border-color: #36F572; z-index: auto; @@ -199,7 +202,7 @@ rect.getting-result-time-proportion { stroke: #36F572; } -.vis.timeline .item.executor.added { +.vis-timeline .vis-item.executor.added { background-color: #A0DFFF; border-color: #3EC0FF; } @@ -209,7 +212,7 @@ rect.getting-result-time-proportion { stroke: #3EC0FF; } -.vis.timeline .item.executor.removed { +.vis-timeline .vis-item.executor.removed { background-color: #FFA1B0; border-color: #FF4D6D; } @@ -219,7 +222,7 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis.timeline .item.executor.selected { +.vis-timeline .vis-item.executor.vis-selected { background-color: #A2FCC0; border-color: #36F572; z-index: 2; @@ -258,15 +261,15 @@ span.expand-task-assignment-timeline { cursor: pointer; } -.vis.timeline .item.range .content { +.vis-timeline .vis-item.vis-range .vis-item-content { position: unset; } -.vis.timeline .item .tooltip-inner { +.vis-timeline .vis-item .tooltip-inner { max-width: unset !important; } -.vispanel.center { +.vis-panel.vis-center { font-size: 12px; line-height: 12px; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 9ab5684d901f..a6153ceda75e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -41,7 +41,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime, offset) { setupExecutorEventAction(); function setupJobEventAction() { - $(".item.range.job.application-timeline-object").each(function() { + $(".vis-item.vis-range.job.application-timeline-object").each(function() { var getSelectorForJobEntry = function(baseElem) { var jobIdText = $($(baseElem).find(".application-timeline-content")[0]).text(); var jobId = jobIdText.match("\\(Job (\\d+)\\)$")[1]; @@ -116,7 +116,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime, offset) { setupExecutorEventAction(); function setupStageEventAction() { - $(".item.range.stage.job-timeline-object").each(function() { + $(".vis-item.vis-range.stage.job-timeline-object").each(function() { var getSelectorForStageEntry = function(baseElem) { var stageIdText = $($(baseElem).find(".job-timeline-content")[0]).text(); var stageIdAndAttempt = stageIdText.match("\\(Stage (\\d+\\.\\d+)\\)$")[1].split("."); @@ -233,7 +233,7 @@ $(function (){ }); function setupExecutorEventAction() { - $(".item.box.executor").each(function () { + $(".vis-item.vis-box.executor").each(function () { $(this).hover( function() { $($(this).find(".executor-event-content")[0]).tooltip("show"); diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 932ba16812bb..6f320c524201 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -230,7 +230,7 @@ private[spark] class ExecutorAllocationManager( } } } - executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) + executor.scheduleWithFixedDelay(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 98c3abe93b55..341b337a9f88 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -75,11 +75,18 @@ object Partitioner { * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will * produce an unexpected or incorrect result. */ -class HashPartitioner(partitions: Int) extends Partitioner { +class HashPartitioner(partitions: Int, buckets: Int = 0) extends Partitioner { require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.") + require(buckets >= 0, s"Number of buckets ($buckets) cannot be negative.") + + def this(partitions: Int) { + this(partitions, 0) + } def numPartitions: Int = partitions + def numBuckets: Int = buckets + def getPartition(key: Any): Int = key match { case null => 0 case _ => Utils.nonNegativeMod(key.hashCode, numPartitions) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 33ed0d5493e0..c161cf5a0e20 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark @@ -62,7 +80,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { private[spark] def loadFromSystemProperties(silent: Boolean): SparkConf = { // Load any spark.* system properties - for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { + for ((key, value) <- Utils.getSystemProperties + if key.startsWith("spark.") || key.startsWith("snappydata.")) { set(key, value, silent) } this diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index fe15052b6247..0c1dc6ec4a4d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -14,14 +14,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark import java.io._ import java.lang.reflect.Constructor -import java.net.URI +import java.net.{MalformedURLException, URI} import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} -import java.util.concurrent.ConcurrentMap +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} import scala.collection.JavaConverters._ @@ -35,11 +53,9 @@ import scala.util.control.NonFatal import com.google.common.collect.MapMaker import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, - FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, - TextInputFormat} +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary @@ -47,8 +63,7 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, - WholeTextFileInputFormat} +import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} @@ -249,7 +264,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events - private[spark] val listenerBus = new LiveListenerBus + private[spark] val listenerBus = new LiveListenerBus(this) // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( @@ -262,8 +277,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp - private[spark] val addedFiles = HashMap[String, Long]() - private[spark] val addedJars = HashMap[String, Long]() + private[spark] val addedFiles = new ConcurrentHashMap[String, Long]().asScala + private[spark] val addedJars = new ConcurrentHashMap[String, Long]().asScala // Keeps track of all persisted RDDs private[spark] val persistentRdds = { @@ -788,7 +803,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap - new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) + new ParallelCollectionRDD[T](this, seq.map(_._1), math.max(seq.size, 1), indexToPrefs) } /** @@ -960,6 +975,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(conf) + // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) @@ -980,6 +1000,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(hadoopConfiguration) + // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -1064,6 +1089,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(hadoopConfiguration) + // The call to NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves val job = NewHadoopJob.getInstance(conf) @@ -1098,6 +1128,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(conf) + // Add necessary security credentials to the JobConf. Required to access secure HDFS. val jconf = new JobConf(conf) SparkHadoopUtil.get.addCredentials(jconf) @@ -1399,7 +1434,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * supported for Hadoop-supported filesystems. */ def addFile(path: String, recursive: Boolean): Unit = { - val uri = new URI(path) + val uri = new Path(path).toUri val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString case _ => path @@ -1421,6 +1456,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + "turned on.") } + } else { + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) } val key = if (!isLocal && scheme == "file") { @@ -1429,14 +1467,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli schemeCorrectedPath } val timestamp = System.currentTimeMillis - addedFiles(key) = timestamp - - // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, - hadoopConfiguration, timestamp, useCache = false) - - logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) - postEnvironmentUpdate() + if (addedFiles.putIfAbsent(key, timestamp).isEmpty) { + logInfo(s"Added file $path at $key with timestamp $timestamp") + // Fetch the file locally so that closures which are run on the driver can still use the + // SparkFiles API to access files. + Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, + env.securityManager, hadoopConfiguration, timestamp, useCache = false) + postEnvironmentUpdate() + } } /** @@ -1472,7 +1510,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ - private[spark] override def requestTotalExecutors( + @DeveloperApi + override def requestTotalExecutors( numExecutors: Int, localityAwareTasks: Int, hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] @@ -1679,6 +1718,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => @@ -1704,12 +1745,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli case exc: FileNotFoundException => logError(s"Jar not found at $path") null - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null } } // A JAR file which exists locally on every worker node @@ -1720,11 +1755,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } if (key != null) { - addedJars(key) = System.currentTimeMillis - logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) + val timestamp = System.currentTimeMillis + if (addedJars.putIfAbsent(key, timestamp).isEmpty) { + logInfo(s"Added JAR $path at $key with timestamp $timestamp") + postEnvironmentUpdate() + } } } - postEnvironmentUpdate() } /** @@ -2137,7 +2174,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - listenerBus.start(this) + listenerBus.start() _listenerBusStarted = true } @@ -2191,7 +2228,7 @@ object SparkContext extends Logging { * * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private val activeContext: AtomicReference[SparkContext] = + private[spark] val activeContext: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) /** diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index af50a6dc2d8d..5da0dfcf5ad9 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark @@ -302,10 +320,15 @@ object SparkEnv extends Logging { val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) val memoryManager: MemoryManager = - if (useLegacyMemoryManager) { - new StaticMemoryManager(conf, numUsableCores) - } else { - UnifiedMemoryManager(conf, numUsableCores) + conf.getOption("spark.memory.manager").map(Utils.classForName(_) + .getConstructor(classOf[SparkConf], classOf[Int]) + .newInstance(conf, Int.box(numUsableCores)) + .asInstanceOf[MemoryManager]).getOrElse { + if (useLegacyMemoryManager) { + new StaticMemoryManager(conf, numUsableCores) + } else { + UnifiedMemoryManager(conf, numUsableCores) + } } val blockTransferService = @@ -402,7 +425,8 @@ object SparkEnv extends Logging { // System properties that are not java classpaths val systemProperties = Utils.getSystemProperties.toSeq val otherProperties = systemProperties.filter { case (k, _) => - k != "java.class.path" && !k.startsWith("spark.") + k != "java.class.path" && !k.startsWith("spark.") && + !k.startsWith("snappydata.") }.sorted // Class paths including all added jars and files diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 64cf4981714c..701097ace897 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.1-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.3-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 34c0696bfc4e..ac09c6c497f8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -135,7 +135,7 @@ private[deploy] object DeployMessages { } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], - exitStatus: Option[Int]) + exitStatus: Option[Int], workerLost: Boolean) case class ApplicationRemoved(message: String) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 9feafc99ac07..80611658a164 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -311,7 +311,7 @@ object SparkSubmit { // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. if (args.isPython && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local python files are supported: $args.primaryResource") + printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}") } val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") if (nonLocalPyFiles.nonEmpty) { @@ -322,7 +322,7 @@ object SparkSubmit { // Require all R files to be local if (args.isR && !isYarnCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } } @@ -633,7 +633,14 @@ object SparkSubmit { // explicitly sets `spark.submit.pyFiles` in his/her default properties file. sysProps.get("spark.submit.pyFiles").foreach { pyFiles => val resolvedPyFiles = Utils.resolveURIs(pyFiles) - val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + val formattedPyFiles = if (!isYarnCluster && !isMesosCluster) { + PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + } else { + // Ignoring formatting python path in yarn and mesos cluster mode, these two modes + // support dealing with remote python files, they could distribute and add python files + // locally. + resolvedPyFiles + } sysProps("spark.submit.pyFiles") = formattedPyFiles } @@ -897,9 +904,12 @@ private[spark] object SparkSubmitUtils { val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local") localIvy.setLocal(true) localIvy.setRepository(new FileRepository(localIvyRoot)) - val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", - "[artifact](-[classifier]).[ext]").mkString(File.separator) - localIvy.addIvyPattern(localIvyRoot.getAbsolutePath + File.separator + ivyPattern) + val ivyPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", "[revision]", + "ivys", "ivy.xml").mkString(File.separator) + localIvy.addIvyPattern(ivyPattern) + val artifactPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", + "[revision]", "[type]s", "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.addArtifactPattern(artifactPattern) localIvy.setName("local-ivy-cache") cr.add(localIvy) @@ -944,7 +954,7 @@ private[spark] object SparkSubmitUtils { artifacts.foreach { mvn => val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) - dd.addDependencyConfiguration(ivyConfName, ivyConfName) + dd.addDependencyConfiguration(ivyConfName, ivyConfName + "(runtime)") // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") // scalastyle:on println diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f1761e7c1ec9..80bfced167ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.deploy @@ -129,7 +147,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S */ private def ignoreNonSparkProperties(): Unit = { sparkProperties.foreach { case (k, v) => - if (!k.startsWith("spark.")) { + if (!k.startsWith("spark.") && !k.startsWith("snappydata.")) { sparkProperties -= k SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index a9df732df93c..93f58ce63799 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -21,6 +21,8 @@ import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import scala.concurrent.Future +import scala.util.{Failure, Success} import scala.util.control.NonFatal import org.apache.spark.SparkConf @@ -79,11 +81,6 @@ private[spark] class StandaloneAppClient( private val registrationRetryThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") - // A thread pool to perform receive then reply actions in a thread so as not to block the - // event loop. - private val askAndReplyThreadPool = - ThreadUtils.newDaemonCachedThreadPool("appclient-receive-and-reply-threadpool") - override def onStart(): Unit = { try { registerWithMaster(1) @@ -177,12 +174,12 @@ private[spark] class StandaloneAppClient( cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) - case ExecutorUpdated(id, state, message, exitStatus) => + case ExecutorUpdated(id, state, message, exitStatus, workerLost) => val fullId = appId + "/" + id val messageText = message.map(s => " (" + s + ")").getOrElse("") logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) if (ExecutorState.isFinished(state)) { - listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) + listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) } case MasterChanged(masterRef, masterWebUiUrl) => @@ -220,19 +217,13 @@ private[spark] class StandaloneAppClient( endpointRef: RpcEndpointRef, context: RpcCallContext, msg: T): Unit = { - // Create a thread to ask a message and reply with the result. Allow thread to be + // Ask a message and create a thread to reply with the result. Allow thread to be // interrupted during shutdown, otherwise context must be notified of NonFatal errors. - askAndReplyThreadPool.execute(new Runnable { - override def run(): Unit = { - try { - context.reply(endpointRef.askWithRetry[Boolean](msg)) - } catch { - case ie: InterruptedException => // Cancelled - case NonFatal(t) => - context.sendFailure(t) - } - } - }) + endpointRef.ask[Boolean](msg).andThen { + case Success(b) => context.reply(b) + case Failure(ie: InterruptedException) => // Cancelled + case Failure(NonFatal(t)) => context.sendFailure(t) + }(ThreadUtils.sameThread) } override def onDisconnected(address: RpcAddress): Unit = { @@ -272,7 +263,6 @@ private[spark] class StandaloneAppClient( registrationRetryThread.shutdownNow() registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() - askAndReplyThreadPool.shutdownNow() } } @@ -301,12 +291,12 @@ private[spark] class StandaloneAppClient( * * @return whether the request is acknowledged. */ - def requestTotalExecutors(requestedTotal: Int): Boolean = { + def requestTotalExecutors(requestedTotal: Int): Future[Boolean] = { if (endpoint.get != null && appId.get != null) { - endpoint.get.askWithRetry[Boolean](RequestExecutors(appId.get, requestedTotal)) + endpoint.get.ask[Boolean](RequestExecutors(appId.get, requestedTotal)) } else { logWarning("Attempted to request executors before driver fully initialized.") - false + Future.successful(false) } } @@ -314,12 +304,12 @@ private[spark] class StandaloneAppClient( * Kill the given list of executors through the Master. * @return whether the kill request is acknowledged. */ - def killExecutors(executorIds: Seq[String]): Boolean = { + def killExecutors(executorIds: Seq[String]): Future[Boolean] = { if (endpoint.get != null && appId.get != null) { - endpoint.get.askWithRetry[Boolean](KillExecutors(appId.get, executorIds)) + endpoint.get.ask[Boolean](KillExecutors(appId.get, executorIds)) } else { logWarning("Attempted to kill executors before driver fully initialized.") - false + Future.successful(false) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala index 370b16ce4213..64255ec92b72 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala @@ -36,5 +36,6 @@ private[spark] trait StandaloneAppClientListener { def executorAdded( fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit - def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit + def executorRemoved( + fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 2fad1120cdc8..a120b6c5fcdf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -44,7 +44,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") if (allAppsSize > 0) { ++ ++ - + ++ + } else if (requestedIncomplete) {

No incomplete applications found!

} else { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d821474bdb59..c178917d8da3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ @@ -55,6 +56,9 @@ class HistoryServer( // How many applications to retain private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) + // How many applications the summary ui displays + private[history] val maxApplications = conf.get(HISTORY_UI_MAX_APPS); + // application private val appCache = new ApplicationCache(this, retainedApplications, new SystemClock()) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala index 37bfcdfdf477..097728c82157 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -22,6 +22,4 @@ private[master] object ApplicationState extends Enumeration { type ApplicationState = Value val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value - - val MAX_NUM_RETRY = 10 } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f8aac3008cef..dcf41638e799 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -58,6 +58,7 @@ private[deploy] class Master( private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) private val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") + private val MAX_EXECUTOR_RETRIES = conf.getInt("spark.deploy.maxExecutorRetries", 10) val workers = new HashSet[WorkerInfo] val idToApp = new HashMap[String, ApplicationInfo] @@ -251,7 +252,7 @@ private[deploy] class Master( appInfo.resetRetryCount() } - exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus, false)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app @@ -265,19 +266,20 @@ private[deploy] class Master( val normalExit = exitStatus == Some(0) // Only retry certain number of times so we don't go into an infinite loop. - if (!normalExit) { - if (appInfo.incrementRetryCount() < ApplicationState.MAX_NUM_RETRY) { - schedule() - } else { - val execs = appInfo.executors.values - if (!execs.exists(_.state == ExecutorState.RUNNING)) { - logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " + - s"${appInfo.retryCount} times; removing it") - removeApplication(appInfo, ApplicationState.FAILED) - } + // Important note: this code path is not exercised by tests, so be very careful when + // changing this `if` condition. + if (!normalExit + && appInfo.incrementRetryCount() >= MAX_EXECUTOR_RETRIES + && MAX_EXECUTOR_RETRIES >= 0) { // < 0 disables this application-killing path + val execs = appInfo.executors.values + if (!execs.exists(_.state == ExecutorState.RUNNING)) { + logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " + + s"${appInfo.retryCount} times; removing it") + removeApplication(appInfo, ApplicationState.FAILED) } } } + schedule() case None => logWarning(s"Got status update for unknown executor $appId/$execId") } @@ -764,7 +766,7 @@ private[deploy] class Master( for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) exec.application.driver.send(ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None)) + exec.id, ExecutorState.LOST, Some("worker lost"), None, workerLost = true)) exec.state = ExecutorState.LOST exec.application.removeExecutor(exec) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 8875fc223250..18c5d0bd0194 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.master.ExecutorDesc -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { @@ -69,6 +69,16 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } } +
  • + + Executor Limit: + { + if (app.executorLimit == Int.MaxValue) "Unlimited" else app.executorLimit + } + ({app.executors.size} granted) + +
  • Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index e30839c49c04..1175fa347ea3 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -69,6 +69,9 @@ private[spark] class CoarseGrainedExecutorBackend( }(ThreadUtils.sameThread) } + protected def registerExecutor: Executor = + new Executor(executorId, hostname, env, userClassPath, isLocal = false) + def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) @@ -79,7 +82,7 @@ private[spark] class CoarseGrainedExecutorBackend( case RegisteredExecutor => logInfo("Successfully registered with driver") try { - executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) + executor = registerExecutor } catch { case NonFatal(e) => exitExecutor(1, "Unable to create executor due to " + e.getMessage, e) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9a017f29f7d2..1e3c6505d535 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -34,7 +34,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task} +import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ @@ -60,7 +60,7 @@ private[spark] class Executor( // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + protected val currentJars: HashMap[String, Long] = new HashMap[String, Long]() private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) @@ -402,7 +402,12 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { - SparkUncaughtExceptionHandler.uncaughtException(t) + if (!isLocal) { + Thread.getDefaultUncaughtExceptionHandler. + uncaughtException(Thread.currentThread(), t) + } else { + SparkUncaughtExceptionHandler.uncaughtException(t) + } } } finally { @@ -415,7 +420,7 @@ private[spark] class Executor( * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path */ - private def createClassLoader(): MutableURLClassLoader = { + protected def createClassLoader(): MutableURLClassLoader = { // Bootstrap the list of jars with the user class path. val now = System.currentTimeMillis() userClassPath.foreach { url => diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 5bb505bf09f1..52a349919e33 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,9 @@ package org.apache.spark.executor +import java.util.{ArrayList, Collections} + +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} import org.apache.spark._ @@ -99,7 +102,11 @@ class TaskMetrics private[spark] () extends Serializable { /** * Storage statuses of any blocks that have been updated as a result of this task. */ - def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.value + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. + _updatedBlockStatuses.value.asScala + } // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = @@ -114,8 +121,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit = _updatedBlockStatuses.add(v) - private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v) + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.setValue(v.asJava) /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted @@ -225,6 +234,15 @@ class TaskMetrics private[spark] () extends Serializable { } private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums + + /** + * Looks for a registered accumulator by accumulator name. + */ + private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { + accumulators.find { acc => + acc.name.isDefined && acc.name.get == name + } + } } @@ -259,7 +277,7 @@ private[spark] object TaskMetrics extends Logging { val name = info.name.get val value = info.update.get if (name == UPDATED_BLOCK_STATUSES) { - tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]]) + tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]]) } else { tm.nameToAccums.get(name).foreach( _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long]) @@ -290,8 +308,8 @@ private[spark] object TaskMetrics extends Logging { private[spark] class BlockStatusesAccumulator - extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] { - private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)] + extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] { + private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]()) override def isZero(): Boolean = _seq.isEmpty @@ -299,25 +317,27 @@ private[spark] class BlockStatusesAccumulator override def copy(): BlockStatusesAccumulator = { val newAcc = new BlockStatusesAccumulator - newAcc._seq = _seq.clone() + newAcc._seq.addAll(_seq) newAcc } override def reset(): Unit = _seq.clear() - override def add(v: (BlockId, BlockStatus)): Unit = _seq += v + override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v) - override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]]) - : Unit = other match { - case o: BlockStatusesAccumulator => _seq ++= o.value - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + override def merge( + other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = { + other match { + case o: BlockStatusesAccumulator => _seq.addAll(o.value) + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } } - override def value: Seq[(BlockId, BlockStatus)] = _seq + override def value: java.util.List[(BlockId, BlockStatus)] = _seq - def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = { + def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = { _seq.clear() - _seq ++= newValue + _seq.addAll(newValue) } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 05dd68300f89..29f812a2ce11 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -103,4 +103,18 @@ package object config { .stringConf .checkValues(Set("hive", "in-memory")) .createWithDefault("in-memory") + + // To limit memory usage, we only track information for a fixed number of tasks + private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks") + .intConf + .createWithDefault(100000) + + // To limit how many applications are shown in the History Server summary ui + private[spark] val HISTORY_UI_MAX_APPS = + ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE) + + private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = + ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") + .intConf + .createWithDefault(10000) } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 63d1d1767a8c..d47b75544fdb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo assertValid() val blockManager = SparkEnv.get.blockManager val blockId = split.asInstanceOf[BlockRDDPartition].blockId - blockManager.get(blockId) match { + blockManager.get[T](blockId) match { case Some(block) => block.data.asInstanceOf[Iterator[T]] case None => throw new Exception("Could not compute split, block " + blockId + " not found") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 515fd6f4e278..297d95a73101 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -22,7 +22,6 @@ import java.text.SimpleDateFormat import java.util.Date import scala.collection.immutable.Map -import scala.collection.mutable.ListBuffer import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} @@ -317,7 +316,7 @@ class HadoopRDD[K, V]( try { val lsplit = c.inputSplitWithLocationInfo.cast(hsplit) val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) + HadoopRDD.convertSplitLocationInfo(infos) } catch { case e: Exception => logDebug("Failed to use InputSplitWithLocations.", e) @@ -419,21 +418,20 @@ private[spark] object HadoopRDD extends Logging { None } - private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { - val out = ListBuffer[String]() - infos.foreach { loc => - val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. - getLocation.invoke(loc).asInstanceOf[String] + private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Option[Seq[String]] = { + Option(infos).map(_.flatMap { loc => + val reflections = HadoopRDD.SPLIT_INFO_REFLECTIONS.get + val locationStr = reflections.getLocation.invoke(loc).asInstanceOf[String] if (locationStr != "localhost") { - if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory. - invoke(loc).asInstanceOf[Boolean]) { - logDebug("Partition " + locationStr + " is cached by Hadoop.") - out += new HDFSCacheTaskLocation(locationStr).toString + if (reflections.isInMemory.invoke(loc).asInstanceOf[Boolean]) { + logDebug(s"Partition $locationStr is cached by Hadoop.") + Some(HDFSCacheTaskLocation(locationStr).toString) } else { - out += new HostTaskLocation(locationStr).toString + Some(HostTaskLocation(locationStr).toString) } + } else { + None } - } - out.seq + }) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index b086baa08408..2d9d69dfb8fb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -255,7 +255,7 @@ class NewHadoopRDD[K, V]( case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) + HadoopRDD.convertSplitLocationInfo(infos) } catch { case e : Exception => logDebug("Failed to use InputSplit#getLocationInfo.", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a4905dd51b94..34d32aacfb62 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -70,7 +70,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details + * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ abstract class RDD[T: ClassTag]( @@ -474,12 +474,17 @@ abstract class RDD[T: ClassTag]( def sample( withReplacement: Boolean, fraction: Double, - seed: Long = Utils.random.nextLong): RDD[T] = withScope { - require(fraction >= 0.0, "Negative fraction value: " + fraction) - if (withReplacement) { - new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) - } else { - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) + seed: Long = Utils.random.nextLong): RDD[T] = { + require(fraction >= 0, + s"Fraction must be nonnegative, but got ${fraction}") + + withScope { + require(fraction >= 0.0, "Negative fraction value: " + fraction) + if (withReplacement) { + new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) + } else { + new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) + } } } @@ -493,14 +498,22 @@ abstract class RDD[T: ClassTag]( */ def randomSplit( weights: Array[Double], - seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope { - val sum = weights.sum - val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) - normalizedCumWeights.sliding(2).map { x => - randomSampleWithRange(x(0), x(1), seed) - }.toArray + seed: Long = Utils.random.nextLong): Array[RDD[T]] = { + require(weights.forall(_ >= 0), + s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") + require(weights.sum > 0, + s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}") + + withScope { + val sum = weights.sum + val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) + normalizedCumWeights.sliding(2).map { x => + randomSampleWithRange(x(0), x(1), seed) + }.toArray + } } + /** * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability * range. diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 8171dcc04637..ad1fddbde7b0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.collection.parallel.ForkJoinTaskSupport +import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport} import scala.concurrent.forkjoin.ForkJoinPool import scala.reflect.ClassTag @@ -58,6 +58,11 @@ private[spark] class UnionPartition[T: ClassTag]( } } +object UnionRDD { + private[spark] lazy val partitionEvalTaskSupport = + new ForkJoinTaskSupport(new ForkJoinPool(8)) +} + @DeveloperApi class UnionRDD[T: ClassTag]( sc: SparkContext, @@ -68,13 +73,10 @@ class UnionRDD[T: ClassTag]( private[spark] val isPartitionListingParallel: Boolean = rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10) - @transient private lazy val partitionEvalTaskSupport = - new ForkJoinTaskSupport(new ForkJoinPool(8)) - override def getPartitions: Array[Partition] = { val parRDDs = if (isPartitionListingParallel) { val parArray = rdds.par - parArray.tasksupport = partitionEvalTaskSupport + parArray.tasksupport = UnionRDD.partitionEvalTaskSupport parArray } else { rdds diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index d305de2e1340..a02cf30a5d83 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,7 +17,7 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -42,8 +42,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val inbox = new Inbox(ref, endpoint) } - private val endpoints = new ConcurrentHashMap[String, EndpointData] - private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] + private val endpoints: ConcurrentMap[String, EndpointData] = + new ConcurrentHashMap[String, EndpointData] + private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] = + new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] // Track the receivers whose inboxes may contain messages. private val receivers = new LinkedBlockingQueue[EndpointData] diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index afcb023a99da..780fadd5bda8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -66,14 +66,18 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) } override def addFile(file: File): String = { - require(files.putIfAbsent(file.getName(), file) == null, - s"File ${file.getName()} already registered.") + val existingPath = files.putIfAbsent(file.getName, file) + require(existingPath == null || existingPath == file, + s"File ${file.getName} was already registered with a different path " + + s"(old path = $existingPath, new path = $file") s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addJar(file: File): String = { - require(jars.putIfAbsent(file.getName(), file) == null, - s"JAR ${file.getName()} already registered.") + val existingPath = jars.putIfAbsent(file.getName, file) + require(existingPath == null || existingPath == file, + s"File ${file.getName} was already registered with a different path " + + s"(old path = $existingPath, new path = $file") s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5291b663667e..e7e2ff1718f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -233,8 +233,8 @@ class DAGScheduler( /** * Called by TaskScheduler implementation when an executor fails. */ - def executorLost(execId: String): Unit = { - eventProcessLoop.post(ExecutorLost(execId)) + def executorLost(execId: String, reason: ExecutorLossReason): Unit = { + eventProcessLoop.post(ExecutorLost(execId, reason)) } /** @@ -1277,18 +1277,20 @@ class DAGScheduler( s"has failed the maximum allowable number of " + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + s"Most recent failure reason: ${failureMessage}", None) - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } else { + if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. + // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage } - failedStages += failedStage - failedStages += mapStage // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { mapStage.removeOutputLoc(mapId, bmAddress) @@ -1297,7 +1299,7 @@ class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch)) } } @@ -1323,15 +1325,16 @@ class DAGScheduler( * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * * We will also assume that we've lost all shuffle blocks associated with the executor if the - * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed - * occurred, in which case we presume all shuffle data related to this executor to be lost. + * executor serves its own blocks (i.e., we're not using external shuffle), the entire slave + * is lost (likely including the shuffle service), or a FetchFailed occurred, in which case we + * presume all shuffle data related to this executor to be lost. * * Optionally the epoch during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ private[scheduler] def handleExecutorLost( execId: String, - fetchFailed: Boolean, + filesLost: Boolean, maybeEpoch: Option[Long] = None) { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { @@ -1339,7 +1342,8 @@ class DAGScheduler( logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) { + if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { + logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) @@ -1643,8 +1647,12 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case ExecutorAdded(execId, host) => dagScheduler.handleExecutorAdded(execId, host) - case ExecutorLost(execId) => - dagScheduler.handleExecutorLost(execId, fetchFailed = false) + case ExecutorLost(execId, reason) => + val filesLost = reason match { + case SlaveLost(_, true) => true + case _ => false + } + dagScheduler.handleExecutorLost(execId, filesLost) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 8c761124824a..03781a2a2b56 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -77,7 +77,8 @@ private[scheduler] case class CompletionEvent( private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent -private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorLost(execId: String, reason: ExecutorLossReason) + extends DAGSchedulerEvent private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 642bf81ac087..46a35b6a2eaf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -51,6 +51,10 @@ private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed */ private [spark] object LossReasonPending extends ExecutorLossReason("Pending loss reason.") +/** + * @param _message human readable loss reason + * @param workerLost whether the worker is confirmed lost too (i.e. including shuffle service) + */ private[spark] -case class SlaveLost(_message: String = "Slave lost") +case class SlaveLost(_message: String = "Slave lost", workerLost: Boolean = false) extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 1c21313d1cb1..5533f7b1f236 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.util.DynamicVariable -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -32,24 +33,36 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus extends SparkListenerBus { +private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus { self => import LiveListenerBus._ - private var sparkContext: SparkContext = null - // Cap the capacity of the event queue so we get an explicit error (rather than // an OOM exception) if it's perpetually being added to more quickly than it's being drained. - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + private lazy val EVENT_QUEUE_CAPACITY = validateAndGetQueueSize() + private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + + private def validateAndGetQueueSize(): Int = { + val queueSize = sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_SIZE) + if (queueSize <= 0) { + throw new SparkException("spark.scheduler.listenerbus.eventqueue.size must be > 0!") + } + queueSize + } // Indicate if `start()` is called private val started = new AtomicBoolean(false) // Indicate if `stop()` is called private val stopped = new AtomicBoolean(false) + /** A counter for dropped events. It will be reset every time we log it. */ + private val droppedEventsCounter = new AtomicLong(0L) + + /** When `droppedEventsCounter` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + // Indicate if we are processing some event // Guarded by `self` private var processingEvent = false @@ -96,11 +109,9 @@ private[spark] class LiveListenerBus extends SparkListenerBus { * listens for any additional events asynchronously while the listener bus is still running. * This should only be called once. * - * @param sc Used to stop the SparkContext in case the listener thread dies. */ - def start(sc: SparkContext): Unit = { + def start(): Unit = { if (started.compareAndSet(false, true)) { - sparkContext = sc listenerThread.start() } else { throw new IllegalStateException(s"$name already started!") @@ -118,6 +129,24 @@ private[spark] class LiveListenerBus extends SparkListenerBus { eventLock.release() } else { onDropEvent(event) + droppedEventsCounter.incrementAndGet() + } + + val droppedEvents = droppedEventsCounter.get + if (droppedEvents > 0) { + // Don't log too frequently + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + // There may be multiple threads trying to decrease droppedEventsCounter. + // Use "compareAndSet" to make sure only one thread can win. + // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and + // then that thread will update it. + if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + + new java.util.Date(prevLastReportTimestamp)) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 15f863b66c6e..1ed36bf0692f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import java.util.Properties +import scala.collection.mutable import scala.collection.mutable.HashMap import org.apache.spark._ @@ -198,8 +199,8 @@ private[spark] object Task { */ def serializeWithDependencies( task: Task[_], - currentFiles: HashMap[String, Long], - currentJars: HashMap[String, Long], + currentFiles: mutable.Map[String, Long], + currentJars: mutable.Map[String, Long], serializer: SerializerInstance) : ByteBuffer = { @@ -229,6 +230,7 @@ private[spark] object Task { dataOut.flush() val taskBytes = serializer.serialize(task) Utils.writeByteBuffer(taskBytes, out) + out.close() out.toByteBuffer } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 7dd4f6e9d2d9..d22321b88fb8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -332,6 +332,7 @@ private[spark] class TaskSchedulerImpl( def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { var failedExecutor: Option[String] = None + var reason: Option[ExecutorLossReason] = None synchronized { try { if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { @@ -339,8 +340,9 @@ private[spark] class TaskSchedulerImpl( val execId = taskIdToExecutorId(tid) if (executorIdToTaskCount.contains(execId)) { - removeExecutor(execId, + reason = Some( SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) + removeExecutor(execId, reason.get) failedExecutor = Some(execId) } } @@ -373,7 +375,8 @@ private[spark] class TaskSchedulerImpl( } // Update the DAGScheduler without holding a lock on this, since that can deadlock if (failedExecutor.isDefined) { - dagScheduler.executorLost(failedExecutor.get) + assert(reason.isDefined) + dagScheduler.executorLost(failedExecutor.get, reason.get) backend.reviveOffers() } } @@ -499,7 +502,7 @@ private[spark] class TaskSchedulerImpl( } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock if (failedExecutor.isDefined) { - dagScheduler.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get, reason) backend.reviveOffers() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 08d33f688a16..24baaffbe0ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.scheduler @@ -312,7 +330,9 @@ private[spark] class TaskSetManager( // Check for node-local tasks if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- speculatableTasks if canRunOnHost(index)) { + for (index <- speculatableTasks if canRunOnHost(index) && + // don't return executor-local tasks that are still alive + canRunOnExecutor(execId, index)) { val locations = tasks(index).preferredLocations.map(_.host) if (locations.contains(host)) { speculatableTasks -= index @@ -335,7 +355,9 @@ private[spark] class TaskSetManager( // Check for rack-local tasks if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { for (rack <- sched.getRackForHost(host)) { - for (index <- speculatableTasks if canRunOnHost(index)) { + for (index <- speculatableTasks if canRunOnHost(index) + // don't return executor-local tasks that are still alive + if canRunOnExecutor(execId, index)) { val racks = tasks(index).preferredLocations.map(_.host).flatMap(sched.getRackForHost) if (racks.contains(rack)) { speculatableTasks -= index @@ -347,7 +369,9 @@ private[spark] class TaskSetManager( // Check for non-local tasks if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- speculatableTasks if canRunOnHost(index)) { + for (index <- speculatableTasks if canRunOnHost(index) && + // don't return executor-local tasks that are still alive + canRunOnExecutor(execId, index)) { speculatableTasks -= index return Some((index, TaskLocality.ANY)) } @@ -357,6 +381,17 @@ private[spark] class TaskSetManager( None } + private def canRunOnExecutor(execId: String, taskId: Int): Boolean = { + val locations = tasks(taskId).preferredLocations + locations.isEmpty || locations.exists { + case e: ExecutorCacheTaskLocation => execId == e.executorId + case _ => false + } || locations.collectFirst { + case e: ExecutorCacheTaskLocation if sched.isExecutorAlive(e.executorId) + && !executorIsBlacklisted(e.executorId, taskId) => false + }.getOrElse(true) + } + /** * Dequeue a pending task for a given node and return its index and locality level. * Only search for tasks matching the given locality constraint. @@ -371,7 +406,9 @@ private[spark] class TaskSetManager( } if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { - for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) { + for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host)) + // don't return executor-local tasks that are still alive + if canRunOnExecutor(execId, index)) { return Some((index, TaskLocality.NODE_LOCAL, false)) } } @@ -387,13 +424,17 @@ private[spark] class TaskSetManager( for { rack <- sched.getRackForHost(host) index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack)) + // don't return executor-local tasks that are still alive + if canRunOnExecutor(execId, index) } { return Some((index, TaskLocality.RACK_LOCAL, false)) } } if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { - for (index <- dequeueTaskFromList(execId, allPendingTasks)) { + for (index <- dequeueTaskFromList(execId, allPendingTasks) + // don't return executor-local tasks that are still alive + if canRunOnExecutor(execId, index)) { return Some((index, TaskLocality.ANY, false)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 8259923ce31c..db8470a6e6ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.scheduler.cluster @@ -22,6 +40,8 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Future +import scala.concurrent.duration.Duration import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} import org.apache.spark.internal.Logging @@ -49,6 +69,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp protected val totalRegisteredExecutors = new AtomicInteger(0) protected val conf = scheduler.sc.conf private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + private val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. private val _minRegisteredRatio = @@ -168,8 +189,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { executorDataMap.put(executorId, data) - if (currentExecutorIdCounter < executorId.toInt) { - currentExecutorIdCounter = executorId.toInt + // [snappydata] skip toInt used for Yarn since snappydata's + // executorId is not an integer + try { + if (currentExecutorIdCounter < executorId.toInt) { + currentExecutorIdCounter = executorId.toInt + } + } catch { + case nfe: NumberFormatException => // ignore } if (numPendingExecutors > 0) { numPendingExecutors -= 1 @@ -272,6 +299,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Remove a disconnected slave from the cluster private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + logDebug(s"Asked to remove executor $executorId with reason $reason") executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated @@ -406,14 +434,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2)) } - // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: ExecutorLossReason) { - try { - driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) - } catch { - case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) - } + /** + * Called by subclasses when notified of a lost worker. It just fires the message and returns + * at once. + */ + protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + // Only log the failure since we don't care about the result. + driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { case t => + logError(t.getMessage, t) + }(ThreadUtils.sameThread) } def sufficientResourcesRegistered(): Boolean = true @@ -445,19 +474,24 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Request an additional number of executors from the cluster manager. * @return whether the request is acknowledged. */ - final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + final override def requestExecutors(numAdditionalExecutors: Int): Boolean = { if (numAdditionalExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of additional executor(s) " + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") - logDebug(s"Number of pending executors is now $numPendingExecutors") - numPendingExecutors += numAdditionalExecutors - // Account for executors pending to be added or removed - val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size - doRequestTotalExecutors(newTotal) + val response = synchronized { + numPendingExecutors += numAdditionalExecutors + logDebug(s"Number of pending executors is now $numPendingExecutors") + + // Account for executors pending to be added or removed + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } + + defaultAskTimeout.awaitResult(response) } /** @@ -478,19 +512,24 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp numExecutors: Int, localityAwareTasks: Int, hostToLocalTaskCount: Map[String, Int] - ): Boolean = synchronized { + ): Boolean = { if (numExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of executor(s) " + s"$numExecutors from the cluster manager. Please specify a positive number!") } - this.localityAwareTasks = localityAwareTasks - this.hostToLocalTaskCount = hostToLocalTaskCount + val response = synchronized { + this.localityAwareTasks = localityAwareTasks + this.hostToLocalTaskCount = hostToLocalTaskCount + + numPendingExecutors = + math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) + + doRequestTotalExecutors(numExecutors) + } - numPendingExecutors = - math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) - doRequestTotalExecutors(numExecutors) + defaultAskTimeout.awaitResult(response) } /** @@ -503,16 +542,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * insufficient resources to satisfy the first request. We make the assumption here that the * cluster manager will eventually fulfill all requests when resources free up. * - * @return whether the request is acknowledged. + * @return a future whose evaluation indicates whether the request is acknowledged. */ - protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false + protected def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = + Future.successful(false) /** * Request that the cluster manager kill the specified executors. * @return whether the kill request is acknowledged. If list to kill is empty, it will return * false. */ - final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { + final override def killExecutors(executorIds: Seq[String]): Boolean = { killExecutors(executorIds, replace = false, force = false) } @@ -532,39 +572,53 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp final def killExecutors( executorIds: Seq[String], replace: Boolean, - force: Boolean): Boolean = synchronized { + force: Boolean): Boolean = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") - val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) - unknownExecutors.foreach { id => - logWarning(s"Executor to kill $id does not exist!") - } - // If an executor is already pending to be removed, do not kill it again (SPARK-9795) - // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) - val executorsToKill = knownExecutors - .filter { id => !executorsPendingToRemove.contains(id) } - .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } - - // If we do not wish to replace the executors we kill, sync the target number of executors - // with the cluster manager to avoid allocating new ones. When computing the new target, - // take into account executors that are pending to be added or removed. - if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) - } else { - numPendingExecutors += knownExecutors.size + val response = synchronized { + val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) + unknownExecutors.foreach { id => + logWarning(s"Executor to kill $id does not exist!") + } + + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) + val executorsToKill = knownExecutors + .filter { id => !executorsPendingToRemove.contains(id) } + .filter { id => force || !scheduler.isExecutorBusy(id) } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + + // If we do not wish to replace the executors we kill, sync the target number of executors + // with the cluster manager to avoid allocating new ones. When computing the new target, + // take into account executors that are pending to be added or removed. + val adjustTotalExecutors = + if (!replace) { + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } else { + numPendingExecutors += knownExecutors.size + Future.successful(true) + } + + val killExecutors: Boolean => Future[Boolean] = + if (!executorsToKill.isEmpty) { + _ => doKillExecutors(executorsToKill) + } else { + _ => Future.successful(false) + } + + adjustTotalExecutors.flatMap(killExecutors)(ThreadUtils.sameThread) } - !executorsToKill.isEmpty && doKillExecutors(executorsToKill) + defaultAskTimeout.awaitResult(response) } /** * Kill the given list of executors through the cluster manager. * @return whether the kill request is acknowledged. */ - protected def doKillExecutors(executorIds: Seq[String]): Boolean = false - + protected def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = + Future.successful(false) } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 8382fbe9ddb8..04d40e2907cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import scala.concurrent.Future + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{StandaloneAppClient, StandaloneAppClientListener} @@ -148,10 +150,11 @@ private[spark] class StandaloneSchedulerBackend( fullId, hostPort, cores, Utils.megabytesToString(memory))) } - override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { + override def executorRemoved( + fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code, exitCausedByApp = true, message) - case None => SlaveLost(message) + case None => SlaveLost(message, workerLost = workerLost) } logInfo("Executor %s removed: %s".format(fullId, message)) removeExecutor(fullId.split("/")(1), reason) @@ -173,12 +176,12 @@ private[spark] class StandaloneSchedulerBackend( * * @return whether the request is acknowledged. */ - protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { Option(client) match { case Some(c) => c.requestTotalExecutors(requestedTotal) case None => logWarning("Attempted to request executors before driver fully initialized.") - false + Future.successful(false) } } @@ -186,12 +189,12 @@ private[spark] class StandaloneSchedulerBackend( * Kill the given list of executors through the Master. * @return whether the kill request is acknowledged. */ - protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { + protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { Option(client) match { case Some(c) => c.killExecutors(executorIds) case None => logWarning("Attempted to kill executors before driver fully initialized.") - false + Future.successful(false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 99e6d3958374..473b1be4e20e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -24,6 +24,7 @@ import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.{Buffer, HashMap, HashSet} +import scala.concurrent.Future import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} @@ -552,7 +553,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( taskId: String, reason: String): Unit = { stateLock.synchronized { - removeExecutor(taskId, SlaveLost(reason)) + // Do not call removeExecutor() after this scheduler backend was stopped because + // removeExecutor() internally will send a message to the driver endpoint but + // the driver endpoint is not available now, otherwise an exception will be thrown. + if (!stopCalled) { + removeExecutor(taskId, SlaveLost(reason)) + } slaves(slaveId).taskIDs.remove(taskId) } } @@ -572,7 +578,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( super.applicationId } - override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future.successful { // We don't truly know if we can fulfill the full amount of executors // since at coarse grain it depends on the amount of slaves available. logInfo("Capping the total amount of executors to " + requestedTotal) @@ -580,7 +586,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( true } - override def doKillExecutors(executorIds: Seq[String]): Boolean = { + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful { if (mesosDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") false diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 9dc274c9fe28..59bdc88464a8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -68,7 +68,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - private def canUseKryo(ct: ClassTag[_]): Boolean = { + def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag } @@ -128,8 +128,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) + } + + /** Serializes into a chunked byte buffer. */ + def dataSerializeWithExplicitClassTag( + blockId: BlockId, + values: Iterator[_], + classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) - dataSerializeStream(blockId, bbos, values) + val byteStream = new BufferedOutputStream(bbos) + val ser = getSerializer(classTag).newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -137,11 +147,12 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * Deserializes an InputStream into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserializeStream[T: ClassTag]( + def dataDeserializeStream[T]( blockId: BlockId, - inputStream: InputStream): Iterator[T] = { + inputStream: InputStream) + (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) - getSerializer(implicitly[ClassTag[T]]) + getSerializer(classTag) .newInstance() .deserializeStream(wrapForCompression(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 94d8c0d0fd3e..8d6396bededa 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -139,48 +139,54 @@ private[spark] class IndexShuffleBlockResolver( dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length + try { + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() } - } { - out.close() - } - val dataFile = getDataFile(shuffleId, mapId) - // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure - // the following check and rename are atomic. - synchronized { - val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) - if (existingLengths != null) { - // Another attempt for the same task has already written our map outputs successfully, - // so just use the existing partition lengths and delete our temporary map outputs. - System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) - if (dataTmp != null && dataTmp.exists()) { - dataTmp.delete() - } - indexTmp.delete() - } else { - // This is the first successful attempt in writing the map outputs for this task, - // so override any existing index and data files with the ones we wrote. - if (indexFile.exists()) { - indexFile.delete() - } - if (dataFile.exists()) { - dataFile.delete() - } - if (!indexTmp.renameTo(indexFile)) { - throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) - } - if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { - throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + val dataFile = getDataFile(shuffleId, mapId) + // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure + // the following check and rename are atomic. + synchronized { + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { + // Another attempt for the same task has already written our map outputs successfully, + // so just use the existing partition lengths and delete our temporary map outputs. + System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (dataTmp != null && dataTmp.exists()) { + dataTmp.delete() + } + indexTmp.delete() + } else { + // This is the first successful attempt in writing the map outputs for this task, + // so override any existing index and data files with the ones we wrote. + if (indexFile.exists()) { + indexFile.delete() + } + if (dataFile.exists()) { + dataFile.delete() + } + if (!indexTmp.renameTo(indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + } } } + } finally { + if (indexTmp.exists() && !indexTmp.delete()) { + logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 1adacabc86c0..e677270b8460 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -67,10 +67,16 @@ private[spark] class SortShuffleWriter[K, V, C]( // (see SPARK-3570). val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) val tmp = Utils.tempFileWith(output) - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + try { + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + } finally { + if (tmp.exists() && !tmp.delete()) { + logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") + } + } } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 02fd2985fa20..075b9ba37dc8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -29,7 +29,8 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { def appList( @QueryParam("status") status: JList[ApplicationStatus], @DefaultValue("2010-01-01") @QueryParam("minDate") minDate: SimpleDateParam, - @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam) + @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam, + @QueryParam("limit") limit: Integer) : Iterator[ApplicationInfo] = { val allApps = uiRoot.getApplicationInfoList val adjStatus = { @@ -41,7 +42,7 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { } val includeCompleted = adjStatus.contains(ApplicationStatus.COMPLETED) val includeRunning = adjStatus.contains(ApplicationStatus.RUNNING) - allApps.filter { app => + val appList = allApps.filter { app => val anyRunning = app.attempts.exists(!_.completed) // if any attempt is still running, we consider the app to also still be running val statusOk = (!anyRunning && includeCompleted) || @@ -53,6 +54,11 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { } statusOk && dateOk } + if (limit != null) { + appList.take(limit) + } else { + appList + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 83a9cbd63d39..37dfbd6818ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -216,7 +216,7 @@ private[spark] class BlockManager( logInfo(s"Reporting ${blockInfoManager.size} blocks to the master.") for ((blockId, info) <- blockInfoManager.entries) { val status = getCurrentBlockStatus(blockId, info) - if (!tryToReportBlockStatus(blockId, info, status)) { + if (info.tellMaster && !tryToReportBlockStatus(blockId, status)) { logError(s"Failed to report $blockId to master; giving up.") return } @@ -279,7 +279,12 @@ private[spark] class BlockManager( } else { getLocalBytes(blockId) match { case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) - case None => throw new BlockNotFoundException(blockId.toString) + case None => + // If this block manager receives a request for a block that it doesn't have then it's + // likely that the master has outdated block statuses for this block. Therefore, we send + // an RPC so that this block is marked as being unavailable from this block manager. + reportBlockStatus(blockId, BlockStatus.empty) + throw new BlockNotFoundException(blockId.toString) } } } @@ -297,7 +302,7 @@ private[spark] class BlockManager( /** * Get the BlockStatus for the block identified by the given ID, if it exists. - * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. + * NOTE: This is mainly for testing. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfoManager.get(blockId).map { info => @@ -332,10 +337,9 @@ private[spark] class BlockManager( */ private def reportBlockStatus( blockId: BlockId, - info: BlockInfo, status: BlockStatus, droppedMemorySize: Long = 0L): Unit = { - val needReregister = !tryToReportBlockStatus(blockId, info, status, droppedMemorySize) + val needReregister = !tryToReportBlockStatus(blockId, status, droppedMemorySize) if (needReregister) { logInfo(s"Got told to re-register updating block $blockId") // Re-registering will report our new block for free. @@ -351,17 +355,12 @@ private[spark] class BlockManager( */ private def tryToReportBlockStatus( blockId: BlockId, - info: BlockInfo, status: BlockStatus, droppedMemorySize: Long = 0L): Boolean = { - if (info.tellMaster) { - val storageLevel = status.storageLevel - val inMemSize = Math.max(status.memSize, droppedMemorySize) - val onDiskSize = status.diskSize - master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize) - } else { - true - } + val storageLevel = status.storageLevel + val inMemSize = Math.max(status.memSize, droppedMemorySize) + val onDiskSize = status.diskSize + master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize) } /** @@ -373,7 +372,7 @@ private[spark] class BlockManager( info.synchronized { info.level match { case null => - BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L) + BlockStatus.empty case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) @@ -497,7 +496,8 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get) + serializerManager.dataSerializeWithExplicitClassTag( + blockId, memoryStore.getValues(blockId).get, info.classTag) } else { handleLocalReadFailure(blockId) } @@ -518,10 +518,11 @@ private[spark] class BlockManager( * * This does not acquire a lock on this block in this JVM. */ - private def getRemoteValues(blockId: BlockId): Option[BlockResult] = { + private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = { + val ct = implicitly[ClassTag[T]] getRemoteBytes(blockId).map { data => val values = - serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true)) + serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct) new BlockResult(values, DataReadMethod.Network, data.size) } } @@ -562,8 +563,9 @@ private[spark] class BlockManager( // Give up trying anymore locations. Either we've tried all of the original locations, // or we've refreshed the list of locations from the master, and have still // hit failures after trying locations from the refreshed list. - throw new BlockFetchException(s"Failed to fetch block after" + - s" ${totalFailureCount} fetch failures. Most recent failure cause:", e) + logWarning(s"Failed to fetch block after $totalFailureCount fetch failures. " + + s"Most recent failure cause:", e) + return None } logWarning(s"Failed to fetch remote block $blockId " + @@ -600,13 +602,13 @@ private[spark] class BlockManager( * any locks if the block was fetched from a remote block manager. The read lock will * automatically be freed once the result's `data` iterator is fully consumed. */ - def get(blockId: BlockId): Option[BlockResult] = { + def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = { val local = getLocalValues(blockId) if (local.isDefined) { logInfo(s"Found block $blockId locally") return local } - val remote = getRemoteValues(blockId) + val remote = getRemoteValues[T](blockId) if (remote.isDefined) { logInfo(s"Found block $blockId remotely") return remote @@ -658,7 +660,7 @@ private[spark] class BlockManager( makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = { // Attempt to read the block from local or remote storage. If it's present, then we don't need // to go through the local-get-or-put path. - get(blockId) match { + get[T](blockId)(classTag) match { case Some(block) => return Left(block) case _ => @@ -805,12 +807,10 @@ private[spark] class BlockManager( // Now that the block is in either the memory, externalBlockStore, or disk store, // tell the master about it. info.size = size - if (tellMaster) { - reportBlockStatus(blockId, info, putBlockStatus) - } - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus) + if (tellMaster && info.tellMaster) { + reportBlockStatus(blockId, putBlockStatus) } + addUpdatedBlockStatusToTaskMetrics(blockId, putBlockStatus) } logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) if (level.replication > 1) { @@ -861,22 +861,38 @@ private[spark] class BlockManager( } val startTimeMs = System.currentTimeMillis - var blockWasSuccessfullyStored: Boolean = false + var exceptionWasThrown: Boolean = true val result: Option[T] = try { val res = putBody(putBlockInfo) - blockWasSuccessfullyStored = res.isEmpty - res - } finally { - if (blockWasSuccessfullyStored) { + exceptionWasThrown = false + if (res.isEmpty) { + // the block was successfully stored if (keepReadLock) { blockInfoManager.downgradeLock(blockId) } else { blockInfoManager.unlock(blockId) } } else { - blockInfoManager.removeBlock(blockId) + removeBlockInternal(blockId, tellMaster = false) logWarning(s"Putting block $blockId failed") } + res + } finally { + // This cleanup is performed in a finally block rather than a `catch` to avoid having to + // catch and properly re-throw InterruptedException. + if (exceptionWasThrown) { + logWarning(s"Putting block $blockId failed due to an exception") + // If an exception was thrown then it's possible that the code in `putBody` has already + // notified the master about the availability of this block, so we need to send an update + // to remove this block location. + removeBlockInternal(blockId, tellMaster = tellMaster) + // The `putBody` code may have also added a new block status to TaskMetrics, so we need + // to cancel that out by overwriting it with an empty block status. We only do this if + // the finally block was entered via an exception because doing this unconditionally would + // cause us to send empty block statuses for every block that failed to be cached due to + // a memory shortage (which is an expected failure, unlike an uncaught exception). + addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty) + } } if (level.replication > 1) { logDebug("Putting block %s with replication took %s" @@ -959,21 +975,26 @@ private[spark] class BlockManager( val putBlockStatus = getCurrentBlockStatus(blockId, info) val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid if (blockWasSuccessfullyStored) { - // Now that the block is in either the memory, externalBlockStore, or disk store, - // tell the master about it. + // Now that the block is in either the memory or disk store, tell the master about it. info.size = size - if (tellMaster) { - reportBlockStatus(blockId, info, putBlockStatus) - } - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus) + if (tellMaster && info.tellMaster) { + reportBlockStatus(blockId, putBlockStatus) } + addUpdatedBlockStatusToTaskMetrics(blockId, putBlockStatus) logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) if (level.replication > 1) { val remoteStartTime = System.currentTimeMillis val bytesToReplicate = doGetLocalBytes(blockId, info) + // [SPARK-16550] Erase the typed classTag when using default serialization, since + // NettyBlockRpcServer crashes when deserializing repl-defined classes. + // TODO(ekl) remove this once the classloader issue on the remote end is fixed. + val remoteClassTag = if (!serializerManager.canUseKryo(classTag)) { + scala.reflect.classTag[Any] + } else { + classTag + } try { - replicate(blockId, bytesToReplicate, level, classTag) + replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { bytesToReplicate.dispose() } @@ -1170,7 +1191,7 @@ private[spark] class BlockManager( done = true // specified number of peers have been replicated to } } catch { - case e: Exception => + case NonFatal(e) => logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e) failures += 1 replicationFailed = true @@ -1195,8 +1216,8 @@ private[spark] class BlockManager( /** * Read a block consisting of a single object. */ - def getSingle(blockId: BlockId): Option[Any] = { - get(blockId).map(_.data.next()) + def getSingle[T: ClassTag](blockId: BlockId): Option[T] = { + get[T](blockId).map(_.data.next().asInstanceOf[T]) } /** @@ -1261,12 +1282,10 @@ private[spark] class BlockManager( val status = getCurrentBlockStatus(blockId, info) if (info.tellMaster) { - reportBlockStatus(blockId, info, status, droppedMemorySize) + reportBlockStatus(blockId, status, droppedMemorySize) } if (blockIsUpdated) { - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) - } + addUpdatedBlockStatusToTaskMetrics(blockId, status) } status.storageLevel } @@ -1306,21 +1325,31 @@ private[spark] class BlockManager( // The block has already been removed; do nothing. logWarning(s"Asked to remove block $blockId, which does not exist") case Some(info) => - // Removals are idempotent in disk store and memory store. At worst, we get a warning. - val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) - if (!removedFromMemory && !removedFromDisk) { - logWarning(s"Block $blockId could not be removed as it was not found in either " + - "the disk, memory, or external block store") - } - blockInfoManager.removeBlock(blockId) - val removeBlockStatus = getCurrentBlockStatus(blockId, info) - if (tellMaster && info.tellMaster) { - reportBlockStatus(blockId, info, removeBlockStatus) - } - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> removeBlockStatus) - } + removeBlockInternal(blockId, tellMaster = tellMaster && info.tellMaster) + addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty) + } + } + + /** + * Internal version of [[removeBlock()]] which assumes that the caller already holds a write + * lock on the block. + */ + private def removeBlockInternal(blockId: BlockId, tellMaster: Boolean): Unit = { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + if (!removedFromMemory && !removedFromDisk) { + logWarning(s"Block $blockId could not be removed as it was not found on disk or in memory") + } + blockInfoManager.removeBlock(blockId) + if (tellMaster) { + reportBlockStatus(blockId, BlockStatus.empty) + } + } + + private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 0666be2dcb01..35cfdb794129 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -14,11 +14,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.storage import java.io.{File, IOException} -import java.util.UUID import org.apache.spark.SparkConf import org.apache.spark.executor.ExecutorExitCode @@ -105,18 +122,18 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** Produces a unique block id and File suitable for storing local intermediate results. */ def createTempLocalBlock(): (TempLocalBlockId, File) = { - var blockId = new TempLocalBlockId(UUID.randomUUID()) + var blockId = new TempLocalBlockId(StorageUtils.newNonSecureRandomUUID()) while (getFile(blockId).exists()) { - blockId = new TempLocalBlockId(UUID.randomUUID()) + blockId = new TempLocalBlockId(StorageUtils.newNonSecureRandomUUID()) } (blockId, getFile(blockId)) } /** Produces a unique block id and File suitable for storing shuffled intermediate results. */ def createTempShuffleBlock(): (TempShuffleBlockId, File) = { - var blockId = new TempShuffleBlockId(UUID.randomUUID()) + var blockId = new TempShuffleBlockId(StorageUtils.newNonSecureRandomUUID()) while (getFile(blockId).exists()) { - blockId = new TempShuffleBlockId(UUID.randomUUID()) + blockId = new TempShuffleBlockId(StorageUtils.newNonSecureRandomUUID()) } (blockId, getFile(blockId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0..ea3882624cb0 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -14,10 +14,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} +import java.util.UUID import scala.collection.Map import scala.collection.mutable @@ -282,4 +301,52 @@ private[spark] object StorageUtils extends Logging { blockLocations } + /** static random number generator for UUIDs */ + private val uuidRnd = new java.util.Random + + /** + * Generate a random UUID for file names etc. Uses non-secure version + * of random number generator to be more efficient given that its not + * critical to have this unique. + * + * Adapted from Android's java.util.UUID source. + */ + final def newNonSecureRandomUUID(): UUID = { + val randomBytes: Array[Byte] = new Array[Byte](16) + uuidRnd.nextBytes(randomBytes) + + var msb = getLong(randomBytes, 0) + var lsb = getLong(randomBytes, 8) + // Set the version field to 4. + msb &= ~(0xfL << 12) + msb |= (4L << 12) + // Set the variant field to 2. Note that the variant field is + // variable-width, so supporting other variants is not just a matter + // of changing the constant 2 below! + lsb &= ~(0x3L << 62) + lsb |= 2L << 62 + new UUID(msb, lsb) + } + + final def getLong(src: Array[Byte], offset: Int): Long = { + var index = offset + var h: Int = (src(index) & 0xff) << 24 + index += 1 + h |= (src(index) & 0xff) << 16 + index += 1 + h |= (src(index) & 0xff) << 8 + index += 1 + h |= (src(index) & 0xff) + index += 1 + + var l = (src(index) & 0xff) << 24 + index += 1 + l |= (src(index) & 0xff) << 16 + index += 1 + l |= (src(index) & 0xff) << 8 + index += 1 + l |= (src(index) & 0xff) + + (h.toLong << 32L) | (l.toLong & 0xffffffffL) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 0349da0d8aa0..3243b9421cc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.storage.memory @@ -33,7 +51,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} import org.apache.spark.unsafe.Platform -import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} +import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -271,10 +289,11 @@ private[spark] class MemoryStore( blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) Right(size) } else { - assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask, + assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, "released too much unroll memory") Left(new PartiallyUnrolledIterator( this, + MemoryMode.ON_HEAP, unrollMemoryUsedByThisBlock, unrolled = arrayValues.toIterator, rest = Iterator.empty)) @@ -283,7 +302,11 @@ private[spark] class MemoryStore( // We ran out of space while unrolling the values for this block logUnrollFailureMessage(blockId, vector.estimateSize()) Left(new PartiallyUnrolledIterator( - this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values)) + this, + MemoryMode.ON_HEAP, + unrollMemoryUsedByThisBlock, + unrolled = vector.iterator, + rest = values)) } } @@ -392,7 +415,7 @@ private[spark] class MemoryStore( redirectableStream, unrollMemoryUsedByThisBlock, memoryMode, - bbos.toChunkedByteBuffer, + bbos, values, classTag)) } @@ -591,11 +614,11 @@ private[spark] class MemoryStore( val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) if (memoryToRelease > 0) { unrollMemoryMap(taskAttemptId) -= memoryToRelease - if (unrollMemoryMap(taskAttemptId) == 0) { - unrollMemoryMap.remove(taskAttemptId) - } memoryManager.releaseUnrollMemory(memoryToRelease, memoryMode) } + if (unrollMemoryMap(taskAttemptId) == 0) { + unrollMemoryMap.remove(taskAttemptId) + } } } } @@ -653,6 +676,7 @@ private[spark] class MemoryStore( * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. * * @param memoryStore the memoryStore, used for freeing memory. + * @param memoryMode the memory mode (on- or off-heap). * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. * @param unrolled an iterator for the partially-unrolled values. * @param rest the rest of the original iterator passed to @@ -660,39 +684,55 @@ private[spark] class MemoryStore( */ private[storage] class PartiallyUnrolledIterator[T]( memoryStore: MemoryStore, + memoryMode: MemoryMode, unrollMemory: Long, - unrolled: Iterator[T], + private[this] var unrolled: Iterator[T], rest: Iterator[T]) extends Iterator[T] { - private[this] var unrolledIteratorIsConsumed: Boolean = false - private[this] var iter: Iterator[T] = { - val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, { - unrolledIteratorIsConsumed = true - memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) - }) - completionIterator ++ rest + private def releaseUnrollMemory(): Unit = { + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + // SPARK-17503: Garbage collects the unrolling memory before the life end of + // PartiallyUnrolledIterator. + unrolled = null + } + + override def hasNext: Boolean = { + if (unrolled == null) { + rest.hasNext + } else if (!unrolled.hasNext) { + releaseUnrollMemory() + rest.hasNext + } else { + true + } } - override def hasNext: Boolean = iter.hasNext - override def next(): T = iter.next() + override def next(): T = { + if (unrolled == null) { + rest.next() + } else if (!unrolled.hasNext) { + releaseUnrollMemory() + rest.next + } else { + unrolled.next() + } + } /** * Called to dispose of this iterator and free its memory. */ def close(): Unit = { - if (!unrolledIteratorIsConsumed) { - memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) - unrolledIteratorIsConsumed = true + if (unrolled != null) { + releaseUnrollMemory() } - iter = null } } /** * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink. */ -private class RedirectableOutputStream extends OutputStream { +private[storage] class RedirectableOutputStream extends OutputStream { private[this] var os: OutputStream = _ def setOutputStream(s: OutputStream): Unit = { os = s } override def write(b: Int): Unit = os.write(b) @@ -712,7 +752,8 @@ private class RedirectableOutputStream extends OutputStream { * @param redirectableOutputStream an OutputStream which can be redirected to a different sink. * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. * @param memoryMode whether the unroll memory is on- or off-heap - * @param unrolled a byte buffer containing the partially-serialized values. + * @param bbos byte buffer output stream containing the partially-serialized values. + * [[redirectableOutputStream]] initially points to this output stream. * @param rest the rest of the original iterator passed to * [[MemoryStore.putIteratorAsValues()]]. * @param classTag the [[ClassTag]] for the block. @@ -721,14 +762,19 @@ private[storage] class PartiallySerializedBlock[T]( memoryStore: MemoryStore, serializerManager: SerializerManager, blockId: BlockId, - serializationStream: SerializationStream, - redirectableOutputStream: RedirectableOutputStream, - unrollMemory: Long, + private val serializationStream: SerializationStream, + private val redirectableOutputStream: RedirectableOutputStream, + val unrollMemory: Long, memoryMode: MemoryMode, - unrolled: ChunkedByteBuffer, + bbos: ChunkedByteBufferOutputStream, rest: Iterator[T], classTag: ClassTag[T]) { + private lazy val unrolledBuffer: ChunkedByteBuffer = { + bbos.close() + bbos.toChunkedByteBuffer + } + // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. @@ -737,7 +783,23 @@ private[storage] class PartiallySerializedBlock[T]( taskContext.addTaskCompletionListener { _ => // When a task completes, its unroll memory will automatically be freed. Thus we do not call // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. - unrolled.dispose() + unrolledBuffer.dispose() + } + } + + // Exposed for testing + private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer + + private[this] var discarded = false + private[this] var consumed = false + + private def verifyNotConsumedAndNotDiscarded(): Unit = { + if (consumed) { + throw new IllegalStateException( + "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.") + } + if (discarded) { + throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock") } } @@ -745,15 +807,18 @@ private[storage] class PartiallySerializedBlock[T]( * Called to dispose of this block and free its memory. */ def discard(): Unit = { - try { - // We want to close the output stream in order to free any resources associated with the - // serializer itself (such as Kryo's internal buffers). close() might cause data to be - // written, so redirect the output stream to discard that data. - redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) - serializationStream.close() - } finally { - unrolled.dispose() - memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + if (!discarded) { + try { + // We want to close the output stream in order to free any resources associated with the + // serializer itself (such as Kryo's internal buffers). close() might cause data to be + // written, so redirect the output stream to discard that data. + redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) + serializationStream.close() + } finally { + discarded = true + unrolledBuffer.dispose() + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + } } } @@ -762,8 +827,10 @@ private[storage] class PartiallySerializedBlock[T]( * and then serializing the values from the original input iterator. */ def finishWritingToStream(os: OutputStream): Unit = { + verifyNotConsumedAndNotDiscarded() + consumed = true // `unrolled`'s underlying buffers will be freed once this input stream is fully read: - ByteStreams.copy(unrolled.toInputStream(dispose = true), os) + ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os) memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) redirectableOutputStream.setOutputStream(os) while (rest.hasNext) { @@ -780,13 +847,22 @@ private[storage] class PartiallySerializedBlock[T]( * `close()` on it to free its resources. */ def valuesIterator: PartiallyUnrolledIterator[T] = { + verifyNotConsumedAndNotDiscarded() + consumed = true + // Close the serialization stream so that the serializer's internal buffers are freed and any + // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream. + serializationStream.close() // `unrolled`'s underlying buffers will be freed once this input stream is fully read: val unrolledIter = serializerManager.dataDeserializeStream( - blockId, unrolled.toInputStream(dispose = true))(classTag) + blockId, unrolledBuffer.toInputStream(dispose = true))(classTag) + // The unroll memory will be freed once `unrolledIter` is fully consumed in + // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any + // extra unroll memory will automatically be freed by a `finally` block in `Task`. new PartiallyUnrolledIterator( memoryStore, + memoryMode, unrollMemory, - unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()), + unrolled = unrolledIter, rest = rest) } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 2d2d80be4aab..3cc5353f475f 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -90,4 +90,10 @@ private[spark] object ToolTips { val TASK_TIME = "Shaded red when garbage collection (GC) time is over 10% of task time" + + val APPLICATION_EXECUTOR_LIMIT = + """Maximum number of executors that this application will use. This limit is finite only when + dynamic allocation is enabled. The number of granted executors may exceed the limit + ephemerally when executors are being killed. + """ } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index f0a1174a71d3..22136a6f1074 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -26,11 +26,15 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { private val listener = parent.listener + private def removePass(kv: (String, String)): (String, String) = { + if (kv._1.toLowerCase.contains("password")) (kv._1, "******") else kv + } + def render(request: HttpServletRequest): Seq[Node] = { val runtimeInformationTable = UIUtils.listingTable( propertyHeader, jvmRow, listener.jvmInformation, fixedWidth = true) val sparkPropertiesTable = UIUtils.listingTable( - propertyHeader, propertyRow, listener.sparkProperties, fixedWidth = true) + propertyHeader, propertyRow, listener.sparkProperties.map(removePass), fixedWidth = true) val systemPropertiesTable = UIUtils.listingTable( propertyHeader, propertyRow, listener.systemProperties, fixedWidth = true) val classpathEntriesTable = UIUtils.listingTable( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 842f42b4c98d..38ad6e985c4b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -19,12 +19,13 @@ package org.apache.spark.ui.jobs import java.util.concurrent.TimeoutException -import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap, ListBuffer} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId @@ -93,6 +94,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val retainedStages = conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) val retainedJobs = conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS) + val retainedTasks = conf.get(UI_RETAINED_TASKS) // We can test for memory leaks by ensuring that collections that track non-active jobs and // stages do not grow without bound and that collections for active jobs/stages eventually become @@ -400,6 +402,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskData.updateTaskMetrics(taskMetrics) taskData.errorMessage = errorMessage + // If Tasks is too large, remove and garbage collect old tasks + if (stageData.taskData.size > retainedTasks) { + stageData.taskData = stageData.taskData.drop(stageData.taskData.size - retainedTasks) + } + for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); jobId <- activeJobsDependentOnStage; diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index d986a55959b8..d93a660d8555 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -131,7 +131,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = tasks.count(_.taskInfo.finished) + val numCompleted = stageData.numCompleteTasks + val totalTasks = stageData.numActiveTasks + + stageData.numCompleteTasks + stageData.numFailedTasks + val totalTasksNumStr = if (totalTasks == tasks.size) { + s"$totalTasks" + } else { + s"$totalTasks, showing ${tasks.size}" + } val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } @@ -576,7 +583,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++

    Aggregated Metrics by Executor

    ++ executorTable.toNodeSeq ++ maybeAccumulableTable ++ -

    Tasks

    ++ taskTableHTML ++ jsForScrollingDownToTaskTable +

    Tasks ({totalTasksNumStr})

    ++ + taskTableHTML ++ jsForScrollingDownToTaskTable UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } } @@ -628,9 +636,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime val executorComputingTimeProportion = - (100 - schedulerDelayProportion - shuffleReadTimeProportion - + math.max(100 - schedulerDelayProportion - shuffleReadTimeProportion - shuffleWriteTimeProportion - serializationTimeProportion - - deserializationTimeProportion - gettingResultTimeProportion) + deserializationTimeProportion - gettingResultTimeProportion, 0) val schedulerDelayProportionPos = 0 val deserializationTimeProportionPos = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index d76a0e657c28..818605003eaf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -18,12 +18,11 @@ package org.apache.spark.ui.jobs import scala.collection.mutable -import scala.collection.mutable.HashMap +import scala.collection.mutable.{HashMap, LinkedHashMap} import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.{ShuffleReadMetrics, ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} -import org.apache.spark.storage.{BlockId, BlockStatus} import org.apache.spark.util.AccumulatorContext import org.apache.spark.util.collection.OpenHashSet @@ -94,7 +93,7 @@ private[spark] object UIData { var description: Option[String] = None var accumulables = new HashMap[Long, AccumulableInfo] - var taskData = new HashMap[Long, TaskUIData] + var taskData = new LinkedHashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] def hasInput: Boolean = inputBytes > 0 @@ -142,7 +141,6 @@ private[spark] object UIData { memoryBytesSpilled = m.memoryBytesSpilled, diskBytesSpilled = m.diskBytesSpilled, peakExecutionMemory = m.peakExecutionMemory, - updatedBlockStatuses = m.updatedBlockStatuses.toList, inputMetrics = InputMetricsUIData(m.inputMetrics.bytesRead, m.inputMetrics.recordsRead), outputMetrics = OutputMetricsUIData(m.outputMetrics.bytesWritten, m.outputMetrics.recordsWritten), @@ -190,7 +188,6 @@ private[spark] object UIData { memoryBytesSpilled: Long, diskBytesSpilled: Long, peakExecutionMemory: Long, - updatedBlockStatuses: Seq[(BlockId, BlockStatus)], inputMetrics: InputMetricsUIData, outputMetrics: OutputMetricsUIData, shuffleReadMetrics: ShuffleReadMetricsUIData, diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 044dd69cc92c..470d912ecff1 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -19,10 +19,12 @@ package org.apache.spark.util import java.{lang => jl} import java.io.ObjectInputStream -import java.util.ArrayList +import java.util.{ArrayList, Collections} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong +import scala.collection.JavaConverters._ + import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} import org.apache.spark.scheduler.AccumulableInfo @@ -36,6 +38,9 @@ private[spark] case class AccumulatorMetadata( /** * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of * type `OUT`. + * + * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely + * (e.g., synchronized collections) because it will be read from other threads. */ abstract class AccumulatorV2[IN, OUT] extends Serializable { private[spark] var metadata: AccumulatorMetadata = _ @@ -131,7 +136,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { def reset(): Unit /** - * Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator. + * Takes the inputs and accumulates. */ def add(v: IN): Unit @@ -257,6 +262,16 @@ private[spark] object AccumulatorContext { originals.clear() } + /** + * Looks for a registered accumulator by accumulator name. + */ + private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { + originals.values().asScala.find { ref => + val acc = ref.get + acc != null && acc.name.isDefined && acc.name.get == name + }.map(_.get) + } + // Identifier for distinguishing SQL metrics from other accumulators private[spark] val SQL_ACCUM_IDENTIFIER = "sql" } @@ -421,7 +436,7 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { * @since 2.0.0 */ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { - private val _list: java.util.List[T] = new ArrayList[T] + private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) override def isZero: Boolean = _list.isEmpty diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index 7def44bd2a2b..7576faa99c96 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -69,12 +69,17 @@ private[spark] class Benchmark( * @param name of the benchmark case * @param numIters if non-zero, forces exactly this many iterations to be run */ - def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = { - addTimerCase(name, numIters) { timer => + def addCase( + name: String, + numIters: Int = 0, + prepare: () => Unit = () => { }, + cleanup: () => Unit = () => { })(f: Int => Unit): Unit = { + val timedF = (timer: Benchmark.Timer) => { timer.startTiming() f(timer.iteration) timer.stopTiming() } + benchmarks += Benchmark.Case(name, timedF, numIters, prepare, cleanup) } /** @@ -101,7 +106,12 @@ private[spark] class Benchmark( val results = benchmarks.map { c => println(" Running case: " + c.name) - measure(valuesPerIteration, c.numIters)(c.fn) + try { + c.prepare() + measure(valuesPerIteration, c.numIters)(c.fn) + } finally { + c.cleanup() + } } println @@ -188,7 +198,12 @@ private[spark] object Benchmark { } } - case class Case(name: String, fn: Timer => Unit, numIters: Int) + case class Case( + name: String, + fn: Timer => Unit, + numIters: Int, + prepare: () => Unit = () => { }, + cleanup: () => Unit = () => { }) case class Result(avgMs: Double, bestRate: Double, bestMs: Double) /** diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala index 09e7579ae960..9077b86f9ba1 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp def getCount(): Int = count + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + require(!closed, "cannot write to a closed ByteBufferOutputStream") + super.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + require(!closed, "cannot write to a closed ByteBufferOutputStream") + super.write(b, off, len) + } + + override def reset(): Unit = { + require(!closed, "cannot reset a closed ByteBufferOutputStream") + super.reset() + } + + override def close(): Unit = { + if (!closed) { + super.close() + closed = true + } + } + def toByteBuffer: ByteBuffer = { - return ByteBuffer.wrap(buf, 0, count) + require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed") + ByteBuffer.wrap(buf, 0, count) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 18547d459eb5..148635f5d241 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -309,11 +309,12 @@ private[spark] object JsonProtocol { case v: Int => JInt(v) case v: Long => JInt(v) // We only have 3 kind of internal accumulator types, so if it's not int or long, it must be - // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]` + // the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]` case v => - JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) + JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map { + case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) }) } } else { @@ -740,7 +741,7 @@ private[spark] object JsonProtocol { val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") (id, status) - } + }.asJava case _ => throw new IllegalArgumentException(s"unexpected json value $value for " + "accumulator " + name.get) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index be1ae401d950..7d41458fdfb3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.util @@ -55,6 +73,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DYN_ALLOCATION_INITIAL_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.storage.StorageUtils /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -282,7 +301,8 @@ private[spark] object Utils extends Logging { maxAttempts + " attempts!") } try { - dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString) + dir = new File(root, namePrefix + "-" + + StorageUtils.newNonSecureRandomUUID().toString) if (dir.exists() || !dir.mkdirs()) { dir = null } @@ -697,6 +717,26 @@ private[spark] object Utils extends Logging { } } + /** + * Validate that a given URI is actually a valid URL as well. + * @param uri The URI to validate + */ + @throws[MalformedURLException]("when the URI is an invalid URL") + def validateURL(uri: URI): Unit = { + Option(uri.getScheme).getOrElse("file") match { + case "http" | "https" | "ftp" => + try { + uri.toURL + } catch { + case e: MalformedURLException => + val ex = new MalformedURLException(s"URI (${uri.toString}) is not a valid URL.") + ex.initCause(e) + throw ex + } + case _ => // will not be turned into a URL anyway + } + } + /** * Get the path of a temporary directory. Spark's local directories can be configured through * multiple settings, which are used with the following precedence: @@ -824,7 +864,7 @@ private[spark] object Utils extends Logging { */ def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { for (i <- (arr.length - 1) to 1 by -1) { - val j = rand.nextInt(i) + val j = rand.nextInt(i + 1) val tmp = arr(j) arr(j) = arr(i) arr(i) = tmp @@ -1949,7 +1989,7 @@ private[spark] object Utils extends Logging { val path = Option(filePath).getOrElse(getDefaultPropertiesFile()) Option(path).foreach { confFile => getPropertiesFromFile(confFile).filter { case (k, v) => - k.startsWith("spark.") + k.startsWith("spark.") || k.startsWith("snappydata.") }.foreach { case (k, v) => conf.setIfMissing(k, v) sys.props.getOrElseUpdate(k, v) @@ -2370,7 +2410,28 @@ private[spark] object Utils extends Logging { * Returns a path of temporary file which is in the same directory with `path`. */ def tempFileWith(path: File): File = { - new File(path.getAbsolutePath + "." + UUID.randomUUID()) + var temp: File = null + do { + temp = new File(path.getAbsolutePath + "." + + StorageUtils.newNonSecureRandomUUID()) + } while (temp.exists()) + temp + } + + /** + * Returns a path of temporary file which is in the same directory with `path`. + */ + def tempFileWith(parent: String, prefix: String): File = { + var temp: File = null + do { + val name = if (prefix == null) { + StorageUtils.newNonSecureRandomUUID().toString + } else { + prefix + '.' + StorageUtils.newNonSecureRandomUUID().toString + } + temp = new File(parent, name) + } while (temp.exists()) + temp } /** diff --git a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala new file mode 100644 index 000000000000..828153b86842 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Utilities for working with Spark version strings + */ +private[spark] object VersionUtils { + + private val majorMinorRegex = """^(\d+)\.(\d+)(\..*)?$""".r + + /** + * Given a Spark version string, return the major version number. + * E.g., for 2.0.1-SNAPSHOT, return 2. + */ + def majorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._1 + + /** + * Given a Spark version string, return the minor version number. + * E.g., for 2.0.1-SNAPSHOT, return 0. + */ + def minorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._2 + + /** + * Given a Spark version string, return the (major version number, minor version number). + * E.g., for 2.0.1-SNAPSHOT, return (2, 0). + */ + def majorMinorVersion(sparkVersion: String): (Int, Int) = { + majorMinorRegex.findFirstMatchIn(sparkVersion) match { + case Some(m) => + (m.group(1).toInt, m.group(2).toInt) + case None => + throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" + + s" version string, but it could not find the major and minor version numbers.") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4067acee738e..6ea7307c3c6e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -622,7 +622,9 @@ private[spark] class ExternalSorter[K, V, C]( val ds = deserializeStream deserializeStream = null fileStream = null - ds.close() + if (ds != null) { + ds.close() + } // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop(). // This should also be fixed in ExternalAppendOnlyMap. } diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index 10ab0b3f8996..00cccd33daf9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -149,6 +149,23 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } + def clear() { + // first clear the values array and value for null key + val bitSet = _keySet.getBitSet + val nullV = null.asInstanceOf[V] + val values = _values + var pos = bitSet.nextSetBit(0) + while (pos >= 0) { + values(pos) = nullV + pos = bitSet.nextSetBit(pos + 1) + } + haveNullValue = false + nullValue = nullV + _oldValues = null + // next clear the key set + _keySet.clear() + } + // The following member variables are declared as protected instead of private for the // specialization to work (specialized class extends the non-specialized one and needs access // to the "private" variables). diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 0f6a425e3db9..c9d577f45049 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -212,6 +212,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) + def clear() { + _data = new Array[T](_capacity) + _bitset.clear() + _size = 0 + } + /** * Double the table's size and re-hash everything. We are not really using k, but it is declared * so Scala compiler can specialize this method (which leads to calling the specialized version diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala index 67b50d1e7043..a625b3289538 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala @@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream( */ private[this] var position = chunkSize private[this] var _size = 0 + private[this] var closed: Boolean = false def size: Long = _size + override def close(): Unit = { + if (!closed) { + super.close() + closed = true + } + } + override def write(b: Int): Unit = { + require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream") allocateNewChunkIfNeeded() chunks(lastChunkIndex).put(b.toByte) position += 1 @@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream( } override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream") var written = 0 while (written < len) { allocateNewChunkIfNeeded() @@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream( @inline private def allocateNewChunkIfNeeded(): Unit = { - require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called") if (position == chunkSize) { chunks += allocator(chunkSize) lastChunkIndex += 1 @@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream( } def toChunkedByteBuffer: ChunkedByteBuffer = { + require(closed, "cannot call toChunkedByteBuffer() unless close() has been called") require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once") toChunkedByteBufferWasCalled = true if (lastChunkIndex == -1) { diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json new file mode 100644 index 000000000000..9165f549d7d2 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -0,0 +1,67 @@ +[ { + "id" : "local-1430917381534", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0 + } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0 + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0 + } ] +}, { + "id" : "local-1426533911241", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-03-17T23:11:50.242GMT", + "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0 + }, { + "attemptId" : "1", + "startTime" : "2015-03-16T19:25:10.242GMT", + "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0 + } ] +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 11eec0b49c40..96d86b7278ff 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -39,21 +39,21 @@ } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", + "launchTime" : "2015-05-06T13:03:06.502GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 31, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -74,26 +74,26 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", + "launchTime" : "2015-05-06T13:03:06.505GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -114,22 +114,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.504GMT", + "launchTime" : "2015-05-06T13:03:06.494GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, @@ -137,7 +137,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -154,15 +154,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 3842811, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", + "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -173,7 +173,7 @@ "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -194,13 +194,13 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", @@ -213,7 +213,7 @@ "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -234,30 +234,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.494GMT", + "launchTime" : "2015-05-06T13:03:06.506GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -274,7 +274,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } @@ -479,25 +479,25 @@ } } }, { - "taskId" : 16, - "index" : 16, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.001GMT", + "launchTime" : "2015-05-06T13:03:06.915GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 10, + "executorDeserializeTime" : 9, "executorRunTime" : 84, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -514,22 +514,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 108320, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 16, + "index" : 16, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", + "launchTime" : "2015-05-06T13:03:07.001GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 10, "executorRunTime" : 84, "resultSize" : 2010, "jvmGcTime" : 5, @@ -554,30 +554,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95788, + "writeTime" : 108320, "recordsWritten" : 10 } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", + "launchTime" : "2015-05-06T13:03:07.012GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorRunTime" : 84, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -594,25 +594,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", + "launchTime" : "2015-05-06T13:03:06.925GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -634,25 +634,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", + "launchTime" : "2015-05-06T13:03:07.014GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -674,7 +674,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 11eec0b49c40..96d86b7278ff 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -39,21 +39,21 @@ } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", + "launchTime" : "2015-05-06T13:03:06.502GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 31, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -74,26 +74,26 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", + "launchTime" : "2015-05-06T13:03:06.505GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -114,22 +114,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.504GMT", + "launchTime" : "2015-05-06T13:03:06.494GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, @@ -137,7 +137,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -154,15 +154,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 3842811, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", + "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -173,7 +173,7 @@ "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -194,13 +194,13 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", @@ -213,7 +213,7 @@ "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -234,30 +234,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.494GMT", + "launchTime" : "2015-05-06T13:03:06.506GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -274,7 +274,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } @@ -479,25 +479,25 @@ } } }, { - "taskId" : 16, - "index" : 16, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.001GMT", + "launchTime" : "2015-05-06T13:03:06.915GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 10, + "executorDeserializeTime" : 9, "executorRunTime" : 84, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -514,22 +514,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 108320, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 16, + "index" : 16, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", + "launchTime" : "2015-05-06T13:03:07.001GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 10, "executorRunTime" : 84, "resultSize" : 2010, "jvmGcTime" : 5, @@ -554,30 +554,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95788, + "writeTime" : 108320, "recordsWritten" : 10 } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", + "launchTime" : "2015-05-06T13:03:07.012GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorRunTime" : 84, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -594,25 +594,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", + "launchTime" : "2015-05-06T13:03:06.925GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -634,25 +634,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", + "launchTime" : "2015-05-06T13:03:07.014GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -674,7 +674,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index 9528d872ef73..e0e9e8140c71 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -39,21 +39,21 @@ } } }, { - "taskId" : 86, - "index" : 86, + "taskId" : 41, + "index" : 41, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.374GMT", + "launchTime" : "2015-05-06T13:03:07.200GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 2, "executorRunTime" : 16, "resultSize" : 2065, "jvmGcTime" : 0, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -74,15 +74,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95848, + "writeTime" : 90765, "recordsWritten" : 10 } } }, { - "taskId" : 41, - "index" : 41, + "taskId" : 43, + "index" : 43, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.200GMT", + "launchTime" : "2015-05-06T13:03:07.204GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -114,22 +114,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 90765, + "writeTime" : 171516, "recordsWritten" : 10 } } }, { - "taskId" : 68, - "index" : 68, + "taskId" : 57, + "index" : 57, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.306GMT", + "launchTime" : "2015-05-06T13:03:07.257GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 3, "executorRunTime" : 16, "resultSize" : 2065, "jvmGcTime" : 0, @@ -154,7 +154,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101750, + "writeTime" : 96849, "recordsWritten" : 10 } } @@ -199,10 +199,10 @@ } } }, { - "taskId" : 43, - "index" : 43, + "taskId" : 68, + "index" : 68, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.204GMT", + "launchTime" : "2015-05-06T13:03:07.306GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -234,15 +234,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 171516, + "writeTime" : 101750, "recordsWritten" : 10 } } }, { - "taskId" : 57, - "index" : 57, + "taskId" : 86, + "index" : 86, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.257GMT", + "launchTime" : "2015-05-06T13:03:07.374GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -253,7 +253,7 @@ "executorRunTime" : 16, "resultSize" : 2065, "jvmGcTime" : 0, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -274,15 +274,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 96849, + "writeTime" : 95848, "recordsWritten" : 10 } } }, { - "taskId" : 59, - "index" : 59, + "taskId" : 32, + "index" : 32, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.265GMT", + "launchTime" : "2015-05-06T13:03:07.148GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -314,22 +314,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 100753, + "writeTime" : 89603, "recordsWritten" : 10 } } }, { - "taskId" : 32, - "index" : 32, + "taskId" : 39, + "index" : 39, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.148GMT", + "launchTime" : "2015-05-06T13:03:07.180GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 2, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -354,22 +354,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 89603, + "writeTime" : 98748, "recordsWritten" : 10 } } }, { - "taskId" : 87, - "index" : 87, + "taskId" : 42, + "index" : 42, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.374GMT", + "launchTime" : "2015-05-06T13:03:07.203GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 10, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -394,15 +394,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 102159, + "writeTime" : 103713, "recordsWritten" : 10 } } }, { - "taskId" : 99, - "index" : 99, + "taskId" : 51, + "index" : 51, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.426GMT", + "launchTime" : "2015-05-06T13:03:07.242GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -417,7 +417,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 70565, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -434,25 +434,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 133964, + "writeTime" : 96013, "recordsWritten" : 10 } } }, { - "taskId" : 63, - "index" : 63, + "taskId" : 59, + "index" : 59, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.276GMT", + "launchTime" : "2015-05-06T13:03:07.265GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 20, + "executorDeserializeTime" : 3, "executorRunTime" : 17, "resultSize" : 2065, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -474,25 +474,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 102779, + "writeTime" : 100753, "recordsWritten" : 10 } } }, { - "taskId" : 90, - "index" : 90, + "taskId" : 63, + "index" : 63, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.385GMT", + "launchTime" : "2015-05-06T13:03:07.276GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 20, "executorRunTime" : 17, "resultSize" : 2065, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -514,22 +514,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98472, + "writeTime" : 102779, "recordsWritten" : 10 } } }, { - "taskId" : 39, - "index" : 39, + "taskId" : 87, + "index" : 87, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.180GMT", + "launchTime" : "2015-05-06T13:03:07.374GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 12, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -554,22 +554,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98748, + "writeTime" : 102159, "recordsWritten" : 10 } } }, { - "taskId" : 42, - "index" : 42, + "taskId" : 90, + "index" : 90, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.203GMT", + "launchTime" : "2015-05-06T13:03:07.385GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 10, + "executorDeserializeTime" : 2, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -594,15 +594,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 103713, + "writeTime" : 98472, "recordsWritten" : 10 } } }, { - "taskId" : 51, - "index" : 51, + "taskId" : 99, + "index" : 99, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.242GMT", + "launchTime" : "2015-05-06T13:03:07.426GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -617,7 +617,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 70565, "recordsRead" : 10000 }, "outputMetrics" : { @@ -634,22 +634,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 96013, + "writeTime" : 133964, "recordsWritten" : 10 } } }, { - "taskId" : 50, - "index" : 50, + "taskId" : 44, + "index" : 44, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.240GMT", + "launchTime" : "2015-05-06T13:03:07.205GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 4, + "executorDeserializeTime" : 3, "executorRunTime" : 18, "resultSize" : 2065, "jvmGcTime" : 0, @@ -674,22 +674,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 90836, + "writeTime" : 98293, "recordsWritten" : 10 } } }, { - "taskId" : 53, - "index" : 53, + "taskId" : 47, + "index" : 47, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.244GMT", + "launchTime" : "2015-05-06T13:03:07.212GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 2, "executorRunTime" : 18, "resultSize" : 2065, "jvmGcTime" : 0, @@ -714,22 +714,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 92835, + "writeTime" : 103015, "recordsWritten" : 10 } } }, { - "taskId" : 44, - "index" : 44, + "taskId" : 50, + "index" : 50, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.205GMT", + "launchTime" : "2015-05-06T13:03:07.240GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 4, "executorRunTime" : 18, "resultSize" : 2065, "jvmGcTime" : 0, @@ -754,25 +754,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98293, + "writeTime" : 90836, "recordsWritten" : 10 } } }, { - "taskId" : 80, - "index" : 80, + "taskId" : 52, + "index" : 52, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.341GMT", + "launchTime" : "2015-05-06T13:03:07.243GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 5, "executorRunTime" : 18, "resultSize" : 2065, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -794,7 +794,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98069, + "writeTime" : 89664, "recordsWritten" : 10 } } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 0515e6e3a631..4e36adc8baf3 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -134,61 +134,31 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } } - test("caching") { + test("repeatedly failing task that crashes JVM with a zero exit code (SPARK-16925)") { + // Ensures that if a task which causes the JVM to exit with a zero exit code will cause the + // Spark job to eventually fail. sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).cache() - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory, serialized, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory and disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) + failAfter(Span(100000, Millis)) { + val thrown = intercept[SparkException] { + sc.parallelize(1 to 1, 1).foreachPartition { _ => System.exit(0) } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("failed 4 times")) + } + // Check that the cluster is still usable: + sc.parallelize(1 to 10).count() } - test("caching in memory and disk, serialized, replicated") { + private def testCaching(storageLevel: StorageLevel): Unit = { sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) - - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) + val data = sc.parallelize(1 to 1000, 10) + val cachedData = data.persist(storageLevel) + assert(cachedData.count === 1000) + assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum === + storageLevel.replication * data.getNumPartitions) + assert(cachedData.count === 1000) + assert(cachedData.count === 1000) // Get all the locations of the first partition and try to fetch the partitions // from those locations. @@ -200,10 +170,26 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) - val deserialized = serializerManager.dataDeserializeStream[Int](blockId, - new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList + val deserialized = serializerManager.dataDeserializeStream(blockId, + new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) } + // This will exercise the getRemoteBytes / getRemoteValues code paths: + assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet) + } + + Seq( + "caching" -> StorageLevel.MEMORY_ONLY, + "caching on disk" -> StorageLevel.DISK_ONLY, + "caching in memory, replicated" -> StorageLevel.MEMORY_ONLY_2, + "caching in memory, serialized, replicated" -> StorageLevel.MEMORY_ONLY_SER_2, + "caching on disk, replicated" -> StorageLevel.DISK_ONLY_2, + "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, + "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 + ).foreach { case (testName, storageLevel) => + test(testName) { + testCaching(storageLevel) + } } test("compute without caching when no partitions fit in memory") { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 5e2ba311ee77..e30349570b7e 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -21,6 +21,7 @@ import java.util.concurrent.{ExecutorService, TimeUnit} import scala.collection.Map import scala.collection.mutable +import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.postfixOps @@ -270,13 +271,13 @@ private class FakeSchedulerBackend( clusterManagerEndpoint: RpcEndpointRef) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - clusterManagerEndpoint.askWithRetry[Boolean]( + protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { + clusterManagerEndpoint.ask[Boolean]( RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } - protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { - clusterManagerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) + protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { + clusterManagerEndpoint.ask[Boolean](KillExecutors(executorIds)) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 4fa3cab18184..c451c596b069 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.net.MalformedURLException import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -173,6 +174,27 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("SPARK-17650: malformed url's throw exceptions before bricking Executors") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + Seq("http", "https", "ftp").foreach { scheme => + val badURL = s"$scheme://user:pwd/path" + val e1 = intercept[MalformedURLException] { + sc.addFile(badURL) + } + assert(e1.getMessage.contains(badURL)) + val e2 = intercept[MalformedURLException] { + sc.addJar(badURL) + } + assert(e2.getMessage.contains(badURL)) + assert(sc.addedFiles.isEmpty) + assert(sc.addedJars.isEmpty) + } + } finally { + sc.stop() + } + } + test("addFile recursive works") { val pluto = Utils.createTempDir() val neptune = Utils.createTempDir(pluto.getAbsolutePath) @@ -216,6 +238,57 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("cannot call addFile with different paths that have the same filename") { + val dir = Utils.createTempDir() + try { + val subdir1 = new File(dir, "subdir1") + val subdir2 = new File(dir, "subdir2") + assert(subdir1.mkdir()) + assert(subdir2.mkdir()) + val file1 = new File(subdir1, "file") + val file2 = new File(subdir2, "file") + Files.write("old", file1, StandardCharsets.UTF_8) + Files.write("new", file2, StandardCharsets.UTF_8) + sc = new SparkContext("local-cluster[1,1,1024]", "test") + sc.addFile(file1.getAbsolutePath) + def getAddedFileContents(): String = { + sc.parallelize(Seq(0)).map { _ => + scala.io.Source.fromFile(SparkFiles.get("file")).mkString + }.first() + } + assert(getAddedFileContents() === "old") + intercept[IllegalArgumentException] { + sc.addFile(file2.getAbsolutePath) + } + assert(getAddedFileContents() === "old") + } finally { + Utils.deleteRecursively(dir) + } + } + + // Regression tests for SPARK-16787 + for ( + schedulingMode <- Seq("local-mode", "non-local-mode"); + method <- Seq("addJar", "addFile") + ) { + val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").toString + val master = schedulingMode match { + case "local-mode" => "local" + case "non-local-mode" => "local-cluster[1,1,1024]" + } + test(s"$method can be called twice with same file in $schedulingMode (SPARK-16787)") { + sc = new SparkContext(master, "test") + method match { + case "addJar" => + sc.addJar(jarPath) + sc.addJar(jarPath) + case "addFile" => + sc.addFile(jarPath) + sc.addFile(jarPath) + } + } + } + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index b2bc8861083b..54693c1bf81e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -577,6 +577,25 @@ class SparkSubmitSuite val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 sysProps3("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) + + // Test remote python files + val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val writer4 = new PrintWriter(f4) + val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + writer4.println("spark.submit.pyFiles " + remotePyFiles) + writer4.close() + val clArgs4 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--properties-file", f4.getPath, + "hdfs:///tmp/mister.py" + ) + val appArgs4 = new SparkSubmitArguments(clArgs4) + val sysProps4 = SparkSubmit.prepareSubmitEnvironment(appArgs4)._3 + // Should not format python path for yarn cluster mode + sysProps4("spark.submit.pyFiles") should be( + Utils.resolveURIs(remotePyFiles) + ) } test("user classpath first in driver") { diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index f6ef9d15ddee..bc58fb2a362a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.concurrent.duration._ import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.{Eventually, ScalaFutures} import org.apache.spark._ import org.apache.spark.deploy.{ApplicationDescription, Command} @@ -36,7 +36,12 @@ import org.apache.spark.util.Utils /** * End-to-end tests for application client in standalone mode. */ -class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterAll { +class AppClientSuite + extends SparkFunSuite + with LocalSparkContext + with BeforeAndAfterAll + with Eventually + with ScalaFutures { private val numWorkers = 2 private val conf = new SparkConf() private val securityManager = new SecurityManager(conf) @@ -93,7 +98,12 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd // Send message to Master to request Executors, verify request by change in executor limit val numExecutorsRequested = 1 - assert(ci.client.requestTotalExecutors(numExecutorsRequested)) + whenReady( + ci.client.requestTotalExecutors(numExecutorsRequested), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) + } eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() @@ -101,10 +111,12 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } // Send request to kill executor, verify request was made - assert { - val apps = getApplications() - val executorId: String = apps.head.executors.head._2.fullId - ci.client.killExecutors(Seq(executorId)) + val executorId: String = getApplications().head.executors.head._2.fullId + whenReady( + ci.client.killExecutors(Seq(executorId)), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) } // Issue stop command for Client to disconnect from Master @@ -122,7 +134,9 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) // requests to master should fail immediately - assert(ci.client.requestTotalExecutors(3) === false) + whenReady(ci.client.requestTotalExecutors(3), timeout(1.seconds)) { success => + assert(success === false) + } } // =============================== @@ -196,7 +210,8 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd execAddedList.add(id) } - def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { + def executorRemoved( + id: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = { execRemovedList.add(id) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 631a7cd9d5d7..ae3f5d9c012e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -100,6 +100,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "minDate app list json" -> "applications?minDate=2015-02-10", "maxDate app list json" -> "applications?maxDate=2015-02-10", "maxDate2 app list json" -> "applications?maxDate=2015-02-03T16:42:40.000GMT", + "limit app list json" -> "applications?limit=3", "one app json" -> "applications/local-1422981780767", "one app multi-attempt json" -> "applications/local-1426533911241", "job list json" -> "applications/local-1422981780767/jobs", diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a..7f20206202cb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index ce4e7a237e9f..5c353021677e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -31,6 +32,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -199,7 +201,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def beforeEach(): Unit = { super.beforeEach() - sc = new SparkContext("local", "DAGSchedulerSuite") + init(new SparkConf()) + } + + private def init(testConf: SparkConf): Unit = { + sc = new SparkContext("local", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() @@ -605,6 +611,46 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + private val shuffleFileLossTests = Seq( + ("slave lost with shuffle service", SlaveLost("", false), true, false), + ("worker lost with shuffle service", SlaveLost("", true), true, true), + ("worker lost without shuffle service", SlaveLost("", true), false, true), + ("executor failure with shuffle service", ExecutorKilled, true, false), + ("executor failure without shuffle service", ExecutorKilled, false, true)) + + for ((eventDescription, event, shuffleServiceOn, expectFileLoss) <- shuffleFileLossTests) { + val maybeLost = if (expectFileLoss) { + "lost" + } else { + "not lost" + } + test(s"shuffle files $maybeLost when $eventDescription") { + // reset the test context with the right shuffle service config + afterEach() + val conf = new SparkConf() + conf.set("spark.shuffle.service.enabled", shuffleServiceOn.toString) + init(conf) + assert(sc.env.blockManager.externalShuffleServiceEnabled == shuffleServiceOn) + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + runEvent(ExecutorLost("exec-hostA", event)) + if (expectFileLoss) { + intercept[MetadataFetchFailedException] { + mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0) + } + } else { + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + } + } + } // Helper function to validate state when creating tests for task failures private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) { @@ -612,7 +658,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(stageAttempt.stageAttemptId == attempt) } - // Helper functions to extract commonly used code in Fetch Failure test cases private def setupStageAbortTest(sc: SparkContext) { sc.listenerBus.addListener(new EndListener()) @@ -1094,7 +1139,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // pretend we were told hostA went away val oldEpoch = mapOutputTracker.getEpoch - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) @@ -1225,7 +1270,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou )) // then one executor dies, and a task fails in stage 1 - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), @@ -1323,7 +1368,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou makeMapStatus("hostA", reduceRdd.partitions.length))) // now that host goes down - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) // so we resubmit those tasks runEvent(makeCompletionEvent(taskSets(0).tasks(0), Resubmitted, null)) @@ -1516,7 +1561,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou submit(reduceRdd, Array(0)) // blockManagerMaster.removeExecutor("exec-hostA") // pretend we were told hostA went away - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks // rather than marking it is as failed and waiting. complete(taskSets(0), Seq( @@ -1983,7 +2028,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // Pretend host A was lost val oldEpoch = mapOutputTracker.getEpoch - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) @@ -2014,6 +2059,61 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + test("SPARK-17644: After one stage is aborted for too many failed attempts, subsequent stages" + + "still behave correctly on fetch failures") { + // Runs a job that always encounters a fetch failure, so should eventually be aborted + def runJobWithPersistentFetchFailure: Unit = { + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + case (x, _) => x + }.count() + } + + // Runs a job that encounters a single fetch failure but succeeds on the second attempt + def runJobWithTemporaryFetchFailure: Unit = { + object FailThisAttempt { + val _fail = new AtomicBoolean(true) + } + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + } + } + + failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + // Run a second job that will fail due to a fetch failure. + // This job will hang without the fix for SPARK-17644. + failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + failAfter(10.seconds) { + try { + runJobWithTemporaryFetchFailure + } catch { + case e: Throwable => fail("A job with one fetch failure should eventually succeed") + } + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index c4c80b5b57da..7f4859206e25 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -142,14 +142,14 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf) - val listenerBus = new LiveListenerBus + val listenerBus = new LiveListenerBus(sc) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() - listenerBus.start(sc) + listenerBus.start() listenerBus.addListener(eventLogger) listenerBus.postToAll(applicationStart) listenerBus.postToAll(applicationEnd) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 5ba67afc0cd6..e8a88d4909a8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -37,13 +37,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L test("don't call sc.stop in listener") { - sc = new SparkContext("local", "SparkListenerSuite") + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus + val bus = new LiveListenerBus(sc) bus.addListener(listener) // Starting listener bus should flush all buffered events - bus.start(sc) + bus.start() bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -52,8 +52,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("basic creation and shutdown of LiveListenerBus") { + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val counter = new BasicJobCounter - val bus = new LiveListenerBus + val bus = new LiveListenerBus(sc) bus.addListener(counter) // Listener bus hasn't started yet, so posting events should not increment counter @@ -61,7 +62,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(counter.count === 0) // Starting listener bus should flush all buffered events - bus.start(sc) + bus.start() bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) @@ -72,14 +73,14 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Listener bus must not be started twice intercept[IllegalStateException] { - val bus = new LiveListenerBus - bus.start(sc) - bus.start(sc) + val bus = new LiveListenerBus(sc) + bus.start() + bus.start() } // ... or stopped before starting intercept[IllegalStateException] { - val bus = new LiveListenerBus + val bus = new LiveListenerBus(sc) bus.stop() } } @@ -106,12 +107,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match drained = true } } - - val bus = new LiveListenerBus + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) + val bus = new LiveListenerBus(sc) val blockingListener = new BlockingListener bus.addListener(blockingListener) - bus.start(sc) + bus.start() bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() @@ -353,13 +354,14 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val badListener = new BadListener val jobCounter1 = new BasicJobCounter val jobCounter2 = new BasicJobCounter - val bus = new LiveListenerBus + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) + val bus = new LiveListenerBus(sc) // Propagate events to bad listener first bus.addListener(badListener) bus.addListener(jobCounter1) bus.addListener(jobCounter2) - bus.start(sc) + bus.start() // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 862313326c93..b98f945bac25 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -44,7 +44,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) override def executorAdded(execId: String, host: String) {} - override def executorLost(execId: String) {} + override def executorLost(execId: String, reason: ExecutorLossReason) {} override def taskSetFailed( taskSet: TaskSet, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 7f21d4c623af..3ffbe70e76bb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.Collections +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.Promise +import scala.reflect.ClassTag import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ @@ -28,18 +32,21 @@ import org.apache.mesos.Protos.Value.Scalar import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ +import org.scalatest.concurrent.ScalaFutures import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.TaskSchedulerImpl class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar - with BeforeAndAfter { + with BeforeAndAfter + with ScalaFutures { private var sparkConf: SparkConf = _ private var driver: SchedulerDriver = _ @@ -47,6 +54,11 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite private var backend: MesosCoarseGrainedSchedulerBackend = _ private var externalShuffleClient: MesosExternalShuffleClient = _ private var driverEndpoint: RpcEndpointRef = _ + @volatile private var stopCalled = false + + // All 'requests' to the scheduler run immediately on the same thread, so + // demand that all futures have their value available immediately. + implicit override val patienceConfig = PatienceConfig(timeout = Duration(0, TimeUnit.SECONDS)) test("mesos supports killing and limiting executors") { setBackend() @@ -62,8 +74,8 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite verifyTaskLaunched("o1") // kills executors - backend.doRequestTotalExecutors(0) - assert(backend.doKillExecutors(Seq("0"))) + assert(backend.doRequestTotalExecutors(0).futureValue) + assert(backend.doKillExecutors(Seq("0")).futureValue) val taskID0 = createTaskId("0") verify(driver, times(1)).killTask(taskID0) @@ -252,6 +264,32 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.start() } + test("Do not call removeExecutor() after backend is stopped") { + setBackend() + + // launches a task on a valid offer + val offers = List((backend.executorMemory(sc), 1)) + offerResources(offers) + verifyTaskLaunched("o1") + + // launches a thread simulating status update + val statusUpdateThread = new Thread { + override def run(): Unit = { + while (!stopCalled) { + Thread.sleep(100) + } + + val status = createTaskStatus("0", "s1", TaskState.TASK_FINISHED) + backend.statusUpdate(driver, status) + } + }.start + + backend.stop() + // Any method of the backend involving sending messages to the driver endpoint should not + // be called after the backend is stopped. + verify(driverEndpoint, never()).askWithRetry(isA(classOf[RemoveExecutor]))(any[ClassTag[_]]) + } + private def verifyDeclinedOffer(driver: SchedulerDriver, offerId: OfferID, filter: Boolean = false): Unit = { @@ -350,6 +388,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite mesosDriver = newDriver } + override def stopExecutors(): Unit = { + stopCalled = true + } + markRegistered() } backend.start() @@ -377,6 +419,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] driverEndpoint = mock[RpcEndpointRef] + when(driverEndpoint.ask(any())(any())).thenReturn(Promise().future) backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) } diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala index f684e16c25f7..1bfb0c1547ec 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.status.api.v1 import java.util.Date -import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap import org.apache.spark.SparkFunSuite import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} @@ -28,7 +28,7 @@ import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} class AllStagesResourceSuite extends SparkFunSuite { def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { - val tasks = new HashMap[Long, TaskUIData] + val tasks = new LinkedHashMap[Long, TaskUIData] taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => tasks(idx.toLong) = TaskUIData( new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 31687e614731..b9e3a364ee22 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -38,7 +38,10 @@ import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { +class BlockManagerReplicationSuite extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { private val conf = new SparkConf(false).set("spark.app.id", "test") private var rpcEnv: RpcEnv = null @@ -91,8 +94,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") + sc = new SparkContext("local", "test", conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(sc))), conf, true) allStores.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6821582254f5..e93eee273f16 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -49,7 +49,7 @@ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with PrivateMethodTester with ResetSystemProperties { + with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { import BlockManagerSuite._ @@ -107,8 +107,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) + sc = new SparkContext("local", "test", conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(sc))), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -511,10 +513,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") store3.stop() store3 = null - // exception throw because there is no locations - intercept[BlockFetchException] { - store.getRemoteBytes("list1") - } + // Should return None instead of throwing an exception: + assert(store.getRemoteBytes("list1").isEmpty) } test("SPARK-14252: getOrElseUpdate should still read from remote storage") { @@ -861,6 +861,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memoryManager.setMemoryStore(store.memoryStore) + store.initialize("app-id") // The put should fail since a1 is not serializable. class UnserializableClass @@ -1184,9 +1185,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE new MockBlockTransferService(conf.getInt("spark.block.failures.beforeLocationRefresh", 5)) store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService)) store.putSingle("item", 999L, StorageLevel.MEMORY_ONLY, tellMaster = true) - intercept[BlockFetchException] { - store.getRemoteBytes("item") - } + assert(store.getRemoteBytes("item").isEmpty) } test("SPARK-13328: refresh block locations (fetch should succeed after location refresh)") { @@ -1208,6 +1207,39 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE verify(mockBlockManagerMaster, times(2)).getLocations("item") } + test("SPARK-17484: block status is properly updated following an exception in put()") { + val mockBlockTransferService = new MockBlockTransferService(maxFailures = 10) { + override def uploadBlock( + hostname: String, + port: Int, execId: String, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Future[Unit] = { + throw new InterruptedException("Intentional interrupt") + } + } + store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService)) + store2 = makeBlockManager(8000, "executor2", transferService = Option(mockBlockTransferService)) + intercept[InterruptedException] { + store.putSingle("item", "value", StorageLevel.MEMORY_ONLY_2, tellMaster = true) + } + assert(store.getLocalBytes("item").isEmpty) + assert(master.getLocations("item").isEmpty) + assert(store2.getRemoteBytes("item").isEmpty) + } + + test("SPARK-17484: master block locations are updated following an invalid remote block fetch") { + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store.putSingle("item", "value", StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(master.getLocations("item").nonEmpty) + store.removeBlock("item", tellMaster = false) + assert(master.getLocations("item").nonEmpty) + assert(store2.getRemoteBytes("item").isEmpty) + assert(master.getLocations("item").isEmpty) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index 145d432afe85..9e10ee560148 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -80,6 +80,13 @@ class MemoryStoreSuite (memoryStore, blockInfoManager) } + private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = { + assert(actual.length === expected.length, s"wrong number of values returned in $hint") + expected.iterator.zip(actual.iterator).foreach { case (e, a) => + assert(e === a, s"$hint did not return original values!") + } + } + test("reserve/release unroll memory") { val (memoryStore, _) = makeMemoryStore(12000) assert(memoryStore.currentUnrollMemory === 0) @@ -138,9 +145,7 @@ class MemoryStoreSuite var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any) assert(putResult.isRight) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } + assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues") blockInfoManager.lockForWriting("unroll") assert(memoryStore.remove("unroll")) blockInfoManager.removeBlock("unroll") @@ -153,9 +158,7 @@ class MemoryStoreSuite assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(memoryStore.contains("someBlock2")) assert(!memoryStore.contains("someBlock1")) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } + assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues") blockInfoManager.lockForWriting("unroll") assert(memoryStore.remove("unroll")) blockInfoManager.removeBlock("unroll") @@ -168,9 +171,7 @@ class MemoryStoreSuite assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(!memoryStore.contains("someBlock2")) assert(putResult.isLeft) - bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => - assert(e === a, "putIterator() did not return original values!") - } + assertSameContents(bigList, putResult.left.get.toSeq, "putIterator") // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. assert(memoryStore.currentUnrollMemoryForThisTask === 0) } @@ -317,9 +318,8 @@ class MemoryStoreSuite assert(res.isLeft) assert(memoryStore.currentUnrollMemoryForThisTask > 0) val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization - valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) => - assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!") - } + assertSameContents( + bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()") // The unroll memory was freed once the iterator was fully traversed. assert(memoryStore.currentUnrollMemoryForThisTask === 0) } @@ -341,12 +341,10 @@ class MemoryStoreSuite res.left.get.finishWritingToStream(bos) // The unroll memory was freed once the block was fully written. assert(memoryStore.currentUnrollMemoryForThisTask === 0) - val deserializationStream = serializerManager.dataDeserializeStream[Any]( - "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any) - deserializationStream.zip(bigList.iterator).foreach { case (e, a) => - assert(e === a, - "PartiallySerializedBlock.finishWritingtoStream() did not write original values!") - } + val deserializedValues = serializerManager.dataDeserializeStream[Any]( + "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq + assertSameContents( + bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()") } test("multiple unrolls by the same thread") { diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala new file mode 100644 index 000000000000..ec4f2637fadd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import org.mockito.Mockito +import org.mockito.Mockito.atLeastOnce +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} + +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.memory.MemoryMode +import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager} +import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream} +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream} +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +class PartiallySerializedBlockSuite + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester { + + private val blockId = new TestBlockId("test") + private val conf = new SparkConf() + private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS) + private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) + + private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream) + private val getRedirectableOutputStream = + PrivateMethod[RedirectableOutputStream]('redirectableOutputStream) + + override protected def beforeEach(): Unit = { + super.beforeEach() + Mockito.reset(memoryStore) + } + + private def partiallyUnroll[T: ClassTag]( + iter: Iterator[T], + numItemsToBuffer: Int): PartiallySerializedBlock[T] = { + + val bbos: ChunkedByteBufferOutputStream = { + val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate)) + Mockito.doAnswer(new Answer[ChunkedByteBuffer] { + override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = { + Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer]) + } + }).when(spy).toChunkedByteBuffer + spy + } + + val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream) + redirectableOutputStream.setOutputStream(bbos) + val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream)) + + (1 to numItemsToBuffer).foreach { _ => + assert(iter.hasNext) + serializationStream.writeObject[T](iter.next()) + } + + val unrollMemory = bbos.size + new PartiallySerializedBlock[T]( + memoryStore, + serializerManager, + blockId, + serializationStream = serializationStream, + redirectableOutputStream, + unrollMemory = unrollMemory, + memoryMode = MemoryMode.ON_HEAP, + bbos, + rest = iter, + classTag = implicitly[ClassTag[T]]) + } + + test("valuesIterator() and finishWritingToStream() cannot be called after discard() is called") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.discard() + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(null) + } + intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("discard() can be called more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.discard() + partiallySerializedBlock.discard() + } + + test("cannot call valuesIterator() more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.valuesIterator + intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("cannot call finishWritingToStream() more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + } + } + + test("cannot call finishWritingToStream() after valuesIterator()") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.valuesIterator + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + } + } + + test("cannot call valuesIterator() after finishWritingToStream()") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("buffers are deallocated in a TaskCompletionListener") { + try { + TaskContext.setTaskContext(TaskContext.empty()) + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() + Mockito.verifyNoMoreInteractions(memoryStore) + } finally { + TaskContext.unset() + } + } + + private def testUnroll[T: ClassTag]( + testCaseName: String, + items: Seq[T], + numItemsToBuffer: Int): Unit = { + + test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + partiallySerializedBlock.discard() + + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + } + + test(s"$testCaseName with finishWritingToStream() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + val bbos = Mockito.spy(new ByteBufferOutputStream()) + partiallySerializedBlock.finishWritingToStream(bbos) + + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + Mockito.verify(bbos).close() + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + + val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val deserialized = + serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq + assert(deserialized === items) + } + + test(s"$testCaseName with valuesIterator() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + val valuesIterator = partiallySerializedBlock.valuesIterator + Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + + val deserializedItems = valuesIterator.toArray.toSeq + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + assert(deserializedItems === items) + } + } + + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50) + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0) + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 50) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 0) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 1000) + testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0) +} + +private case class MyCaseClass(str: String) diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala new file mode 100644 index 000000000000..4253cc8ca4cd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.memory.MemoryMode.ON_HEAP +import org.apache.spark.storage.memory.{MemoryStore, PartiallyUnrolledIterator} + +class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar { + test("join two iterators") { + val unrollSize = 1000 + val unroll = (0 until unrollSize).iterator + val restSize = 500 + val rest = (unrollSize until restSize + unrollSize).iterator + + val memoryStore = mock[MemoryStore] + val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest) + + // Firstly iterate over unrolling memory iterator + (0 until unrollSize).foreach { value => + assert(joinIterator.hasNext) + assert(joinIterator.hasNext) + assert(joinIterator.next() == value) + } + + joinIterator.hasNext + joinIterator.hasNext + verify(memoryStore, times(1)) + .releaseUnrollMemoryForThisTask(Matchers.eq(ON_HEAP), Matchers.eq(unrollSize.toLong)) + + // Secondly, iterate over rest iterator + (unrollSize until unrollSize + restSize).foreach { value => + assert(joinIterator.hasNext) + assert(joinIterator.hasNext) + assert(joinIterator.next() == value) + } + + joinIterator.close() + // MemoryMode.releaseUnrollMemoryForThisTask is called only once + verifyNoMoreInteractions(memoryStore) + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 411a0ddebeb7..f6c8418ba3ac 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.ui.storage import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkFunSuite, Success} -import org.apache.spark.executor.TaskMetrics +import org.apache.spark._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ /** * Test various functionality in the StorageListener that supports the StorageTab. */ -class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { +class StorageTabSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter { private var bus: LiveListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ @@ -43,8 +42,10 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { private val bm1 = BlockManagerId("big", "dog", 1) before { - bus = new LiveListenerBus - storageStatusListener = new StorageStatusListener(new SparkConf()) + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + bus = new LiveListenerBus(sc) + storageStatusListener = new StorageStatusListener(conf) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) bus.addListener(storageListener) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 7a1ee03e4ce5..cda345752393 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite { }) testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) - testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) + testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, blocksJson) // For anything else, we just cast the value to a string testAccumValue(Some("anything"), blocks, JString(blocks.toString)) testAccumValue(Some("anything"), 123, JString("123")) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 30952a945834..4715fd29375d 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -31,6 +31,7 @@ import scala.util.Random import com.google.common.io.Files import org.apache.commons.lang3.SystemUtils +import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -874,4 +875,38 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } } } + + test("chi square test of randomizeInPlace") { + // Parameters + val arraySize = 10 + val numTrials = 1000 + val threshold = 0.05 + val seed = 1L + + // results(i)(j): how many times Utils.randomize moves an element from position j to position i + val results = Array.ofDim[Long](arraySize, arraySize) + + // This must be seeded because even a fair random process will fail this test with + // probability equal to the value of `threshold`, which is inconvenient for a unit test. + val rand = new java.util.Random(seed) + val range = 0 until arraySize + + for { + _ <- 0 until numTrials + trial = Utils.randomizeInPlace(range.toArray, rand) + i <- range + } results(i)(trial(i)) += 1L + + val chi = new ChiSquareTest() + + // We expect an even distribution; this array will be rescaled by `chiSquareTest` + val expected = Array.fill(arraySize * arraySize)(1.0) + val observed = results.flatten + + // Performs Pearson's chi-squared test. Using the sum-of-squares as the test statistic, gives + // the probability of a uniform distribution producing results as extreme as `observed` + val pValue = chi.chiSquareTest(expected, observed) + + assert(pValue > threshold) + } } diff --git a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala new file mode 100644 index 000000000000..aaf79ebd4f9f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class VersionUtilsSuite extends SparkFunSuite { + + import org.apache.spark.util.VersionUtils._ + + test("Parse Spark major version") { + assert(majorVersion("2.0") === 2) + assert(majorVersion("12.10.11") === 12) + assert(majorVersion("2.0.1-SNAPSHOT") === 2) + assert(majorVersion("2.0.x") === 2) + withClue("majorVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + majorVersion("2z.0") + } + } + withClue("majorVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + majorVersion("2.0z") + } + } + } + + test("Parse Spark minor version") { + assert(minorVersion("2.0") === 0) + assert(minorVersion("12.10.11") === 10) + assert(minorVersion("2.0.1-SNAPSHOT") === 0) + assert(minorVersion("2.0.x") === 0) + withClue("minorVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + minorVersion("2z.0") + } + } + withClue("minorVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + minorVersion("2.0z") + } + } + } + + test("Parse Spark major and minor versions") { + assert(majorMinorVersion("2.0") === (2, 0)) + assert(majorMinorVersion("12.10.11") === (12, 10)) + assert(majorMinorVersion("2.0.1-SNAPSHOT") === (2, 0)) + assert(majorMinorVersion("2.0.x") === (2, 0)) + withClue("majorMinorVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + majorMinorVersion("2z.0") + } + } + withClue("majorMinorVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + majorMinorVersion("2.0z") + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala index 226622075a6c..86961745673c 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala @@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("empty output") { val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + o.close() assert(o.toChunkedByteBuffer.size === 0) } test("write a single byte") { val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) o.write(10) + o.close() val chunkedByteBuffer = o.toChunkedByteBuffer assert(chunkedByteBuffer.getChunks().length === 1) assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte)) @@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(new Array[Byte](9)) o.write(99) + o.close() val chunkedByteBuffer = o.toChunkedByteBuffer assert(chunkedByteBuffer.getChunks().length === 1) assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte) @@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(new Array[Byte](10)) o.write(99) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 2) assert(arrays(1).length === 1) @@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) assert(arrays.head.length === ref.length) @@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) assert(arrays.head.length === ref.length) @@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) assert(arrays(0).length === 10) @@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) assert(arrays(0).length === 10) diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index d404939d1cae..b7e5100ca740 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -60,12 +60,27 @@ git config user.email $GIT_EMAIL # Create release version $MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +# Set the release version in R/pkg/DESCRIPTION +sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION +# Set the release version in docs +sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +# Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers +R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` +sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION + +# Update docs with next version +sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +# Use R version for short version +sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing development version $NEXT_VERSION" # Push changes diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 96001eade028..8c9e559790ba 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -46,7 +46,7 @@ curator-recipes-2.4.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.11.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar guava-14.0.1.jar guice-3.0.jar @@ -98,7 +98,7 @@ jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jets3t-0.7.1.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -140,7 +140,7 @@ parquet-jackson-1.7.0.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.1.jar +py4j-0.10.3.jar pyrolite-4.9.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar @@ -151,7 +151,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 9f3d9ad97a9f..839e0840dba3 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -48,7 +48,7 @@ curator-recipes-2.4.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.11.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar guava-14.0.1.jar guice-3.0.jar @@ -103,7 +103,7 @@ jersey-server-2.22.2.jar jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -147,7 +147,7 @@ parquet-jackson-1.7.0.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.1.jar +py4j-0.10.3.jar pyrolite-4.9.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar @@ -158,7 +158,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 3df292ee9956..ed84de79b1fd 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -48,7 +48,7 @@ curator-recipes-2.4.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.11.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar guava-14.0.1.jar guice-3.0.jar @@ -103,7 +103,7 @@ jersey-server-2.22.2.jar jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -147,7 +147,7 @@ parquet-jackson-1.7.0.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.1.jar +py4j-0.10.3.jar pyrolite-4.9.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar @@ -158,7 +158,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9540f5856bce..6e7c9cb5c791 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -52,7 +52,7 @@ curator-recipes-2.6.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.11.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar gson-2.2.4.jar guava-14.0.1.jar @@ -111,7 +111,7 @@ jersey-server-2.22.2.jar jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -155,7 +155,7 @@ parquet-jackson-1.7.0.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.1.jar +py4j-0.10.3.jar pyrolite-4.9.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar @@ -166,7 +166,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index b5c3de75a9c8..a61f31eb8769 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -52,27 +52,27 @@ curator-recipes-2.6.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.11.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.2.jar -hadoop-auth-2.7.2.jar -hadoop-client-2.7.2.jar -hadoop-common-2.7.2.jar -hadoop-hdfs-2.7.2.jar -hadoop-mapreduce-client-app-2.7.2.jar -hadoop-mapreduce-client-common-2.7.2.jar -hadoop-mapreduce-client-core-2.7.2.jar -hadoop-mapreduce-client-jobclient-2.7.2.jar -hadoop-mapreduce-client-shuffle-2.7.2.jar -hadoop-yarn-api-2.7.2.jar -hadoop-yarn-client-2.7.2.jar -hadoop-yarn-common-2.7.2.jar -hadoop-yarn-server-common-2.7.2.jar -hadoop-yarn-server-web-proxy-2.7.2.jar +hadoop-annotations-2.7.3.jar +hadoop-auth-2.7.3.jar +hadoop-client-2.7.3.jar +hadoop-common-2.7.3.jar +hadoop-hdfs-2.7.3.jar +hadoop-mapreduce-client-app-2.7.3.jar +hadoop-mapreduce-client-common-2.7.3.jar +hadoop-mapreduce-client-core-2.7.3.jar +hadoop-mapreduce-client-jobclient-2.7.3.jar +hadoop-mapreduce-client-shuffle-2.7.3.jar +hadoop-yarn-api-2.7.3.jar +hadoop-yarn-client-2.7.3.jar +hadoop-yarn-common-2.7.3.jar +hadoop-yarn-server-common-2.7.3.jar +hadoop-yarn-server-web-proxy-2.7.3.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar @@ -111,7 +111,7 @@ jersey-server-2.22.2.jar jets3t-0.9.3.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -156,7 +156,7 @@ parquet-jackson-1.7.0.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.1.jar +py4j-0.10.3.jar pyrolite-4.9.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar @@ -167,7 +167,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/docs/_config.yml b/docs/_config.yml index 3951cadb0ffd..75c89bd31898 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.0.0 -SPARK_VERSION_SHORT: 2.0.0 +SPARK_VERSION: 2.0.1 +SPARK_VERSION_SHORT: 2.0.1 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 0.21.0 diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index d3bf082aa751..ad5b5c9adfac 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -114,7 +114,7 @@
  • Building Spark
  • Contributing to Spark
  • -
  • Supplemental Projects
  • +
  • Third Party Projects
  • diff --git a/docs/building-spark.md b/docs/building-spark.md index 2c987cf8346e..330df0054154 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -16,11 +16,13 @@ Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. ### Setting up Maven's Memory Usage -You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`. We recommend the following settings: +You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`: - export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" + export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m" -If you don't run this, you may see errors like the following: +When compiling with Java 7, you will need to add the additional option "-XX:MaxPermSize=512M" to MAVEN_OPTS. + +If you don't add these parameters to `MAVEN_OPTS`, you may see errors and warnings like the following: [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... [ERROR] PermGen space -> [Help 1] @@ -28,12 +30,18 @@ If you don't run this, you may see errors like the following: [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... [ERROR] Java heap space -> [Help 1] -You can fix this by setting the `MAVEN_OPTS` variable as discussed before. + [INFO] Compiling 233 Scala sources and 41 Java sources to /Users/me/Development/spark/sql/core/target/scala-{site.SCALA_BINARY_VERSION}/classes... + OpenJDK 64-Bit Server VM warning: CodeCache is full. Compiler has been disabled. + OpenJDK 64-Bit Server VM warning: Try increasing the code cache size using -XX:ReservedCodeCacheSize= + +You can fix these problems by setting the `MAVEN_OPTS` variable as discussed before. **Note:** -* For Java 8 and above this step is not required. -* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automate this for you. +* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automatically add the above options to the `MAVEN_OPTS` environment variable. +* The `test` phase of the Spark build will automatically add these options to `MAVEN_OPTS`, even when not using `build/mvn`. +* You may see warnings like "ignoring option MaxPermSize=1g; support was removed in 8.0" when building or running tests with Java 8 and `build/mvn`. These warnings are harmless. + ### build/mvn diff --git a/docs/configuration.md b/docs/configuration.md index bf10b2481951..db088dde1191 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -123,6 +123,7 @@ of the most common options to set are: Number of cores to use for the driver process, only in cluster mode. + spark.driver.maxResultSize 1g @@ -217,7 +218,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-class-path command line option or in - your default properties file. + your default properties file. @@ -244,7 +245,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-library-path command line option or in - your default properties file. + your default properties file. @@ -597,6 +598,14 @@ Apart from these, the following properties are also available, and may be useful collecting. + + spark.ui.retainedTasks + 100000 + + How many tasks the Spark UI and status APIs remember before garbage + collecting. + + spark.worker.ui.retainedExecutors 1000 @@ -1204,7 +1213,7 @@ Apart from these, the following properties are also available, and may be useful false Whether to use dynamic resource allocation, which scales the number of executors registered - with this application up and down based on the workload. + with this application up and down based on the workload. For more detail, see the description here.

    @@ -1345,8 +1354,9 @@ Apart from these, the following properties are also available, and may be useful spark.authenticate.enableSaslEncryption false - Enable encrypted communication when authentication is enabled. This option is currently - only supported by the block transfer service. + Enable encrypted communication when authentication is + enabled. This is supported by the block transfer service and the + RPC endpoints. @@ -1448,8 +1458,10 @@ Apart from these, the following properties are also available, and may be useful the properties must be overwritten in the protocol-specific namespace.

    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for - particular protocol denoted by YYY. Currently YYY can be - only fs for file server.

    + particular protocol denoted by YYY. Example values for YYY + include fs, ui, standalone, and + historyServer. See SSL + Configuration for details on hierarchical SSL configuration for services.

    diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 2e9966c0a2b6..07b38d9cc9a8 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -24,7 +24,6 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] [Graph.aggregateMessages]: api/scala/index.html#org.apache.spark.graphx.Graph@aggregateMessages[A]((EdgeContext[VD,ED,A])⇒Unit,(A,A)⇒A,TripletFields)(ClassTag[A]):VertexRDD[A] [EdgeContext]: api/scala/index.html#org.apache.spark.graphx.EdgeContext -[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] [GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] [GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] [RDD Persistence]: programming-guide.html#rdd-persistence @@ -67,23 +66,6 @@ operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operato [aggregateMessages](#aggregateMessages)) as well as an optimized variant of the [Pregel](#pregel) API. In addition, GraphX includes a growing collection of graph [algorithms](#graph_algorithms) and [builders](#graph_builders) to simplify graph analytics tasks. - -## Migrating from Spark 1.1 - -GraphX in Spark 1.2 contains a few user facing API changes: - -1. To improve performance we have introduced a new version of -[`mapReduceTriplets`][Graph.mapReduceTriplets] called -[`aggregateMessages`][Graph.aggregateMessages] which takes the messages previously returned from -[`mapReduceTriplets`][Graph.mapReduceTriplets] through a callback ([`EdgeContext`][EdgeContext]) -rather than by return value. -We are deprecating [`mapReduceTriplets`][Graph.mapReduceTriplets] and encourage users to consult -the [transition guide](#mrTripletsTransition). - -2. In Spark 1.0 and 1.1, the type signature of [`EdgeRDD`][EdgeRDD] switched from -`EdgeRDD[ED]` to `EdgeRDD[ED, VD]` to enable some caching optimizations. We have since discovered -a more elegant solution and have restored the type signature to the more natural `EdgeRDD[ED]` type. - # Getting Started To get started you first need to import Spark and GraphX into your project, as follows: @@ -613,7 +595,7 @@ compute the average age of the more senior followers of each user. ### Map Reduce Triplets Transition Guide (Legacy) In earlier versions of GraphX neighborhood aggregation was accomplished using the -[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: +`mapReduceTriplets` operator: {% highlight scala %} class Graph[VD, ED] { @@ -624,7 +606,7 @@ class Graph[VD, ED] { } {% endhighlight %} -The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which +The `mapReduceTriplets` operator takes a user defined map function which is applied to each triplet and can yield *messages* which are aggregated using the user defined `reduce` function. However, we found the user of the returned iterator to be expensive and it inhibited our ability to diff --git a/docs/index.md b/docs/index.md index 0cb8803783a0..a7a92f6c4f6d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -120,7 +120,7 @@ options for deployment: * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -* [Supplemental Projects](https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects): related third party Spark projects +* [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects): related third party Spark projects **External Resources:** diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index f5804fdeee5a..12a03d3c9198 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -49,7 +49,7 @@ MLlib L-BFGS solver calls the corresponding implementation in [breeze](https://g ## Normal equation solver for weighted least squares -MLlib implements normal equation solver for [weighted least squares](https://en.wikipedia.org/wiki/Least_squares#Weighted_least_squares) by [WeightedLeastSquares](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala). +MLlib implements normal equation solver for [weighted least squares](https://en.wikipedia.org/wiki/Least_squares#Weighted_least_squares) by [WeightedLeastSquares]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala). Given $n$ weighted observations $(w_i, a_i, b_i)$: @@ -73,7 +73,7 @@ In order to make the normal equation approach efficient, WeightedLeastSquares re ## Iteratively reweighted least squares (IRLS) -MLlib implements [iteratively reweighted least squares (IRLS)](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares) by [IterativelyReweightedLeastSquares](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala). +MLlib implements [iteratively reweighted least squares (IRLS)](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares) by [IterativelyReweightedLeastSquares]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala). It can be used to find the maximum likelihood estimates of a generalized linear model (GLM), find M-estimator in robust regression and other optimization problems. Refer to [Iteratively Reweighted Least Squares for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives](http://www.jstor.org/stable/2345503) for more information. diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5abec63b7ab4..4607ad3ba681 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -33,7 +33,7 @@ The primary Machine Learning API for Spark is now the [DataFrame](sql-programmin * DataFrames provide a more user-friendly API than RDDs. The many benefits of DataFrames include Spark Datasources, SQL/DataFrame queries, Tungsten and Catalyst optimizations, and uniform APIs across languages. * The DataFrame-based API for MLlib provides a uniform API across ML algorithms and across multiple languages. -* DataFrames facilitate practical ML Pipelines, particularly feature transformations. See the [Pipelines guide](ml-pipeline.md) for details. +* DataFrames facilitate practical ML Pipelines, particularly feature transformations. See the [Pipelines guide](ml-pipeline.html) for details. # Dependencies diff --git a/docs/monitoring.md b/docs/monitoring.md index ee932cfc6d70..1bc3d266b66b 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -114,8 +114,17 @@ The history server can be configured as follows: spark.history.retainedApplications 50 - The number of application UIs to retain. If this cap is exceeded, then the oldest - applications will be removed. + The number of applications to retain UI data for in the cache. If this cap is exceeded, then + the oldest applications will be removed from the cache. If an application is not in the cache, + it will have to be loaded from disk if its accessed from the UI. + + + + spark.history.ui.maxApplications + Int.MaxValue + + The number of applications to display on the history summary page. Application UIs are still + available by accessing their URLs directly even if they are not displayed on the history summary page. @@ -242,7 +251,8 @@ can be identified by their `[attempt-id]`. In the API listed below, when running
    Examples:
    ?minDate=2015-02-10
    ?minDate=2015-02-03T16:42:40.000GMT -
    ?maxDate=[date] latest date/time to list; uses same format as minDate. +
    ?maxDate=[date] latest date/time to list; uses same format as minDate. +
    ?limit=[limit] limits the number of applications listed. /applications/[app-id]/jobs diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 888c12f18635..f82832905ef4 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1097,7 +1097,7 @@ for details. foreach(func) Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. -
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. +
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1348,17 +1348,17 @@ running stages (NOTE: this is not yet supported in Python). Accumulators in the Spark UI

    -An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks -running on a cluster can then add to it using the `add` method or the `+=` operator (in Scala and Python). -However, they cannot read its value. -Only the driver program can read the accumulator's value, using its `value` method. - -The code below shows an accumulator being used to add up the elements of an array: -
    +A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` +to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +using its `value` method. + +The code below shows an accumulator being used to add up the elements of an array: + {% highlight scala %} scala> val accum = sc.longAccumulator("My Accumulator") accum: org.apache.spark.util.LongAccumulator = LongAccumulator(id: 0, name: Some(My Accumulator), value: 0) @@ -1395,14 +1395,21 @@ val myVectorAcc = new VectorAccumulatorV2 sc.register(myVectorAcc, "MyVectorAcc1") {% endhighlight %} -Note that, when programmers define their own type of AccumulatorV2, the resulting type can be same or not same with the elements added. +Note that, when programmers define their own type of AccumulatorV2, the resulting type can be different than that of the elements added.
    +A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` +to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +using its `value` method. + +The code below shows an accumulator being used to add up the elements of an array: + {% highlight java %} -LongAccumulator accum = sc.sc().longAccumulator(); +LongAccumulator accum = jsc.sc().longAccumulator(); sc.parallelize(Arrays.asList(1, 2, 3, 4)).foreach(x -> accum.add(x)); // ... @@ -1412,8 +1419,8 @@ accum.value(); // returns 10 {% endhighlight %} -While this code used the built-in support for accumulators of type Integer, programmers can also -create their own types by subclassing [AccumulatorParam](api/java/index.html?org/apache/spark/AccumulatorParam.html). +Programmers can also create their own types by subclassing +[AccumulatorParam](api/java/index.html?org/apache/spark/AccumulatorParam.html). The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class representing mathematical vectors, we could write: @@ -1440,6 +1447,12 @@ a list by collecting together elements).
    +An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks +running on a cluster can then add to it using the `add` method or the `+=` operator. However, they cannot read its value. +Only the driver program can read the accumulator's value, using its `value` method. + +The code below shows an accumulator being used to add up the elements of an array: + {% highlight python %} >>> accum = sc.accumulator(0) Accumulator @@ -1485,15 +1498,15 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being
    {% highlight scala %} -val accum = sc.accumulator(0) -data.map { x => accum += x; x } +val accum = sc.longAccumulator +data.map { x => accum.add(x); x } // Here, accum is still 0 because no actions have caused the map operation to be computed. {% endhighlight %}
    {% highlight java %} -LongAccumulator accum = sc.sc().longAccumulator(); +LongAccumulator accum = jsc.sc().longAccumulator(); data.map(x -> { accum.add(x); return f(x); }); // Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %} @@ -1531,49 +1544,6 @@ and then call `SparkContext.stop()` to tear it down. Make sure you stop the context within a `finally` block or the test framework's `tearDown` method, as Spark does not support two contexts running concurrently in the same program. -# Migrating from pre-1.0 Versions of Spark - -
    - -
    - -Spark 1.0 freezes the API of Spark Core for the 1.X series, in that any API available today that is -not marked "experimental" or "developer API" will be supported in future versions. -The only change for Scala users is that the grouping operations, e.g. `groupByKey`, `cogroup` and `join`, -have changed from returning `(Key, Seq[Value])` pairs to `(Key, Iterable[Value])`. - -
    - -
    - -Spark 1.0 freezes the API of Spark Core for the 1.X series, in that any API available today that is -not marked "experimental" or "developer API" will be supported in future versions. -Several changes were made to the Java API: - -* The Function classes in `org.apache.spark.api.java.function` became interfaces in 1.0, meaning that old - code that `extends Function` should `implement Function` instead. -* New variants of the `map` transformations, like `mapToPair` and `mapToDouble`, were added to create RDDs - of special data types. -* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning - `(Key, List)` pairs to `(Key, Iterable)`. - -
    - -
    - -Spark 1.0 freezes the API of Spark Core for the 1.X series, in that any API available today that is -not marked "experimental" or "developer API" will be supported in future versions. -The only change for Python users is that the grouping operations, e.g. `groupByKey`, `cogroup` and `join`, -have changed from returning (key, list of values) pairs to (key, iterable of values). - -
    - -
    - -Migration guides are also available for [Spark Streaming](streaming-programming-guide.html#migration-guide-from-091-or-below-to-1x), -[MLlib](ml-guide.html#migration-guide) and [GraphX](graphx-programming-guide.html#migrating-from-spark-091). - - # Where to Go from Here You can see some [example Spark programs](http://spark.apache.org/examples.html) on the Spark website. diff --git a/docs/security.md b/docs/security.md index d2708a80703e..baadfefbec82 100644 --- a/docs/security.md +++ b/docs/security.md @@ -27,7 +27,8 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service. +Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service +and the RPC endpoints. Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle files, cached data, and other application files. If encrypting this data is desired, a workaround is diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index c864c9030835..6f0f665c82e1 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -195,6 +195,21 @@ SPARK_MASTER_OPTS supports the following system properties: the whole cluster by default.
    + + spark.deploy.maxExecutorRetries + 10 + + Limit on the maximum number of back-to-back executor failures that can occur before the + standalone cluster manager removes a faulty application. An application will never be removed + if it has any running executors. If an application experiences more than + spark.deploy.maxExecutorRetries failures in a row, no executors + successfully start running in between those failures, and the application has no running + executors then the standalone cluster manager will remove the application and mark it as failed. + To disable this automatic removal, set spark.deploy.maxExecutorRetries to + -1. +
    + + spark.worker.timeout 60 @@ -333,7 +348,7 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o **Configuration** In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. -For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] +For more information about these configurations please refer to the [configuration doc](configuration.html#deploy) Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). diff --git a/docs/sparkr.md b/docs/sparkr.md index 4bbc362c5208..340e7f7cb1a0 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -62,6 +62,21 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s + + + + + + + + + + + + + + + @@ -110,7 +125,8 @@ head(df) SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. -The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. SparkR supports reading JSON, CSV and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by +The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. +SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects), you can find data source connectors for popular file formats like Avro. These packages can either be added by specifying `--packages` with `spark-submit` or `sparkR` commands, or if initializing SparkSession with `sparkPackages` parameter when in an interactive R shell or from RStudio.
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 33b170e50a00..0bd0093620a3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -132,7 +132,7 @@ from a Hive table, or from [Spark data sources](#data-sources). As an example, the following creates a DataFrame based on the content of a JSON file: -{% include_example create_DataFrames r/RSparkSQLExample.R %} +{% include_example create_df r/RSparkSQLExample.R %}
    @@ -180,7 +180,7 @@ In addition to simple column references and expressions, DataFrames also have a
    -{% include_example dataframe_operations r/RSparkSQLExample.R %} +{% include_example untyped_ops r/RSparkSQLExample.R %} For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). @@ -214,7 +214,7 @@ The `sql` function on a `SparkSession` enables applications to run SQL queries p
    The `sql` function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. -{% include_example sql_query r/RSparkSQLExample.R %} +{% include_example run_sql r/RSparkSQLExample.R %}
    @@ -377,7 +377,7 @@ In the simplest form, the default data source (`parquet` unless otherwise config
    -{% include_example source_parquet r/RSparkSQLExample.R %} +{% include_example generic_load_save_functions r/RSparkSQLExample.R %}
    @@ -400,13 +400,11 @@ using this syntax.
    - {% include_example manual_load_options python/sql/datasource.py %}
    -
    - -{% include_example source_json r/RSparkSQLExample.R %} +
    +{% include_example manual_load_options r/RSparkSQLExample.R %}
    @@ -425,13 +423,11 @@ file directly with SQL.
    - {% include_example direct_sql python/sql/datasource.py %}
    - -{% include_example direct_query r/RSparkSQLExample.R %} +{% include_example direct_sql r/RSparkSQLExample.R %}
    @@ -523,7 +519,7 @@ Using the data from the above example:
    -{% include_example load_programmatically r/RSparkSQLExample.R %} +{% include_example basic_parquet_example r/RSparkSQLExample.R %}
    @@ -827,7 +823,7 @@ Note that the file that is offered as _a json file_ is not a typical JSON file. line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. -{% include_example load_json_file r/RSparkSQLExample.R %} +{% include_example json_dataset r/RSparkSQLExample.R %} @@ -913,7 +909,7 @@ You may need to grant write privilege to the user who starts the spark applicati When working with Hive one must instantiate `SparkSession` with Hive support. This adds support for finding tables in the MetaStore and writing queries using HiveQL. -{% include_example hive_table r/RSparkSQLExample.R %} +{% include_example spark_hive r/RSparkSQLExample.R %} @@ -1045,7 +1041,7 @@ the Data Sources API. The following options are supported: - + @@ -1055,43 +1051,19 @@ the Data Sources API. The following options are supported:
    - -{% highlight scala %} -val jdbcDF = spark.read.format("jdbc").options( - Map("url" -> "jdbc:postgresql:dbserver", - "dbtable" -> "schema.tablename")).load() -{% endhighlight %} - +{% include_example jdbc_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
    - -{% highlight java %} - -Map options = new HashMap<>(); -options.put("url", "jdbc:postgresql:dbserver"); -options.put("dbtable", "schema.tablename"); - -Dataset jdbcDF = spark.read().format("jdbc"). options(options).load(); -{% endhighlight %} - - +{% include_example jdbc_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    - -{% highlight python %} - -df = spark.read.format('jdbc').options(url='jdbc:postgresql:dbserver', dbtable='schema.tablename').load() - -{% endhighlight %} - +{% include_example jdbc_dataset python/sql/datasource.py %}
    - -{% include_example jdbc r/RSparkSQLExample.R %} - +{% include_example jdbc_dataset r/RSparkSQLExample.R %}
    @@ -1175,6 +1147,15 @@ that these options will be deprecated in future release as more optimizations ar scheduled first).
    + + + + + diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 479140f51910..f52bf348fcc9 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -181,7 +181,7 @@ val words = lines.flatMap(_.split(" ")) ... {% endhighlight %} -The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala). +The full source code is in the example [CustomReceiver.scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
    @@ -193,7 +193,7 @@ JavaDStream words = lines.flatMap(new FlatMapFunction() ... {% endhighlight %} -The full source code is in the example [JavaCustomReceiver.java](https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java). +The full source code is in the example [JavaCustomReceiver.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java).
    diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 8eeeee75dbf4..767e1f9402e0 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -63,7 +63,7 @@ configuring Flume agents. By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/flume_wordcount.py). diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md new file mode 100644 index 000000000000..44c39e39446d --- /dev/null +++ b/docs/streaming-kafka-0-10-integration.md @@ -0,0 +1,192 @@ +--- +layout: global +title: Spark Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) +--- + +The Spark Streaming integration for Kafka 0.10 is similar in design to the 0.8 [Direct Stream approach](streaming-kafka-0-8-integration.html#approach-2-direct-approach-no-receivers). It provides simple parallelism, 1:1 correspondence between Kafka partitions and Spark partitions, and access to offsets and metadata. However, because the newer integration uses the [new Kafka consumer API](http://kafka.apache.org/documentation.html#newconsumerapi) instead of the simple API, there are notable differences in usage. This version of the integration is marked as experimental, so the API is potentially subject to change. + +### Linking +For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +### Creating a Direct Stream + Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 + +
    +
    + import org.apache.kafka.clients.consumer.ConsumerRecord + import org.apache.kafka.common.serialization.StringDeserializer + import org.apache.spark.streaming.kafka010._ + import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent + import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe + + val kafkaParams = Map[String, Object]( + "bootstrap.servers" -> "localhost:9092,anotherhost:9092", + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> "example", + "auto.offset.reset" -> "latest", + "enable.auto.commit" -> (false: java.lang.Boolean) + ) + + val topics = Array("topicA", "topicB") + val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Subscribe[String, String](topics, kafkaParams) + ) + + stream.map(record => (record.key, record.value)) + +Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html) +
    +
    +
    +
    + +For possible kafkaParams, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). +Note that enable.auto.commit is disabled, for discussion see [Storing Offsets](streaming-kafka-0-10-integration.html#storing-offsets) below. + +### LocationStrategies +The new Kafka consumer API will pre-fetch messages into buffers. Therefore it is important for performance reasons that the Spark integration keep cached consumers on executors (rather than recreating them for each batch), and prefer to schedule partitions on the host locations that have the appropriate consumers. + +In most cases, you should use `LocationStrategies.PreferConsistent` as shown above. This will distribute partitions evenly across available executors. If your executors are on the same hosts as your Kafka brokers, use `PreferBrokers`, which will prefer to schedule partitions on the Kafka leader for that partition. Finally, if you have a significant skew in load among partitions, use `PreferFixed`. This allows you to specify an explicit mapping of partitions to hosts (any unspecified partitions will use a consistent location). + +The cache for consumers has a default maximum size of 64. If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity` + +### ConsumerStrategies +The new Kafka consumer API has a number of different ways to specify topics, some of which require considerable post-object-instantiation setup. `ConsumerStrategies` provides an abstraction that allows Spark to obtain properly configured consumers even after restart from checkpoint. + +`ConsumerStrategies.Subscribe`, as shown above, allows you to subscribe to a fixed collection of topics. `SubscribePattern` allows you to use a regex to specify topics of interest. Note that unlike the 0.8 integration, using `Subscribe` or `SubscribePattern` should respond to adding partitions during a running stream. Finally, `Assign` allows you to specify a fixed collection of partitions. All three strategies have overloaded constructors that allow you to specify the starting offset for a particular partition. + +If you have specific consumer setup needs that are not met by the options above, `ConsumerStrategy` is a public class that you can extend. + +### Creating an RDD +If you have a use case that is better suited to batch processing, you can create an RDD for a defined range of offsets. + +
    +
    + // Import dependencies and create kafka params as in Create Direct Stream above + + val offsetRanges = Array( + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange("test", 0, 0, 100), + OffsetRange("test", 1, 0, 100) + ) + + val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) + +
    +
    +
    +
    + +Note that you cannot use `PreferBrokers`, because without the stream there is not a driver-side consumer to automatically look up broker metadata for you. Use `PreferFixed` with your own metadata lookups if necessary. + +### Obtaining Offsets + +
    +
    + stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.foreachPartition { iter => + val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + } +
    +
    +
    +
    + +Note that the typecast to `HasOffsetRanges` will only succeed if it is done in the first method called on the result of `createDirectStream`, not later down a chain of methods. Be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + +### Storing Offsets +Kafka delivery semantics in the case of failure depend on how and when offsets are stored. Spark output operations are [at-least-once](streaming-programming-guide.html#semantics-of-output-operations). So if you want the equivalent of exactly-once semantics, you must either store offsets after an idempotent output, or store offsets in an atomic transaction alongside output. With this integration, you have 3 options, in order of increasing reliablity (and code complexity), for how to store offsets. + +#### Checkpoints +If you enable Spark [checkpointing](streaming-programming-guide.html#checkpointing), offsets will be stored in the checkpoint. This is easy to enable, but there are drawbacks. Your output operation must be idempotent, since you will get repeated outputs; transactions are not an option. Furthermore, you cannot recover from a checkpoint if your application code has changed. For planned upgrades, you can mitigate this by running the new code at the same time as the old code (since outputs need to be idempotent anyway, they should not clash). But for unplanned failures that require code changes, you will lose data unless you have another way to identify known good starting offsets. + +#### Kafka itself +Kafka has an offset commit API that stores offsets in a special Kafka topic. By default, the new consumer will periodically auto-commit offsets. This is almost certainly not what you want, because messages successfully polled by the consumer may not yet have resulted in a Spark output operation, resulting in undefined semantics. This is why the stream example above sets "enable.auto.commit" to false. However, you can commit offsets to Kafka after you know your output has been stored, using the `commitAsync` API. The benefit as compared to checkpoints is that Kafka is a durable store regardless of changes to your application code. However, Kafka is not transactional, so your outputs must still be idempotent. + +
    +
    + stream.foreachRDD { rdd => + val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + // some time later, after outputs have completed + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsets) + } + +As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics. +
    +
    +
    +
    + +#### Your own data store +For data stores that support transactions, saving offsets in the same transaction as the results can keep the two in sync, even in failure situations. If you're careful about detecting repeated or skipped offset ranges, rolling back the transaction prevents duplicated or lost messages from affecting results. This gives the equivalent of exactly-once semantics. It is also possible to use this tactic even for outputs that result from aggregations, which are typically hard to make idempotent. + +
    +
    + // The details depend on your data store, but the general idea looks like this + + // begin from the the offsets committed to the database + val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => + new TopicPartition(resultSet.string("topic")), resultSet.int("partition")) -> resultSet.long("offset") + }.toMap + + val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) + ) + + stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + val results = yourCalculation(rdd) + + yourTransactionBlock { + // update results + + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + + // assert that offsets were updated correctly + } + } +
    +
    +
    +
    + +### SSL / TLS +The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html#security_ssl). To enable it, set kafkaParams appropriately before passing to `createDirectStream` / `createRDD`. Note that this only applies to communication between Spark and Kafka brokers; you are still responsible for separately [securing](security.html) Spark inter-node communication. + + +
    +
    + val kafkaParams = Map[String, Object]( + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + "security.protocol" -> "SSL", + "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", + "ssl.truststore.password" -> "test1234", + "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", + "ssl.keystore.password" -> "test1234", + "ssl.key.password" -> "test1234" + ) +
    +
    +
    +
    + +### Deploying + +As with any Spark applications, `spark-submit` is used to launch your application. + +For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md new file mode 100644 index 000000000000..f8f7b95cf745 --- /dev/null +++ b/docs/streaming-kafka-0-8-integration.md @@ -0,0 +1,210 @@ +--- +layout: global +title: Spark Streaming + Kafka Integration Guide (Kafka broker version 0.8.2.1 or higher) +--- +Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. Both approaches are considered stable APIs as of the current version of Spark. + +## Approach 1: Receiver-based Approach +This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. + +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + + For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. + +
    +
    + import org.apache.spark.streaming.kafka._ + + val kafkaStream = KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) + + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). +
    +
    + import org.apache.spark.streaming.kafka.*; + + JavaPairReceiverInputDStream kafkaStream = + KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); + + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + +
    +
    + from pyspark.streaming.kafka import KafkaUtils + + kafkaStream = KafkaUtils.createStream(streamingContext, \ + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/kafka_wordcount.py). +
    +
    + + **Points to remember:** + + - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + + - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. + + - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use +`KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). + +3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-0-8-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-0-8-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. + +## Approach 2: Direct Approach (No Receivers) +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this feature was introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. + +This approach has the following advantages over the receiver-based approach (i.e. Approach 1). + +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. + +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. + +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). + +Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** This approach is supported only in Scala/Java application. Link your SBT/Maven project with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. + +
    +
    + import org.apache.spark.streaming.kafka._ + + val directKafkaStream = KafkaUtils.createDirectStream[ + [key class], [value class], [key decoder class], [value decoder class] ]( + streamingContext, [map of Kafka parameters], [set of topics to consume]) + + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). +
    +
    + import org.apache.spark.streaming.kafka.*; + + JavaPairInputDStream directKafkaStream = + KafkaUtils.createDirectStream(streamingContext, + [key class], [value class], [key decoder class], [value decoder class], + [map of Kafka parameters], [set of topics to consume]); + + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. + See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + +
    +
    + from pyspark.streaming.kafka import KafkaUtils + directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + + You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type. + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/direct_kafka_wordcount.py). +
    +
    + + In the Kafka parameters, you must specify either `metadata.broker.list` or `bootstrap.servers`. + By default, it will start consuming from the latest offset of each Kafka partition. If you set configuration `auto.offset.reset` in Kafka parameters to `smallest`, then it will start consuming from the smallest offset. + + You can also start consuming from any arbitrary offset using other variations of `KafkaUtils.createDirectStream`. Furthermore, if you want to access the Kafka offsets consumed in each batch, you can do the following. + +
    +
    + // Hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + directKafkaStream.transform { rdd => + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.map { + ... + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + ... + } +
    +
    + // Hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference<>(); + + directKafkaStream.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + return rdd; + } + } + ).map( + ... + ).foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + ... + return null; + } + } + ); +
    +
    + offsetRanges = [] + + def storeOffsetRanges(rdd): + global offsetRanges + offsetRanges = rdd.offsetRanges() + return rdd + + def printOffsetRanges(rdd): + for o in offsetRanges: + print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) + + directKafkaStream\ + .transform(storeOffsetRanges)\ + .foreachRDD(printOffsetRanges) +
    +
    + + You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. + + Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. + +3. **Deploying:** This is same as the first approach. diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index e0d3f4f69be8..a8f3667a4985 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,209 +2,52 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. -## Approach 1: Receiver-based Approach -This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. - -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. - -Next, we discuss how to use this approach in your streaming application. - -1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). - - groupId = org.apache.spark - artifactId = spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} - - For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. - -2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. - -
    -
    - import org.apache.spark.streaming.kafka._ - - val kafkaStream = KafkaUtils.createStream(streamingContext, - [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) - - You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). -
    -
    - import org.apache.spark.streaming.kafka.*; - - JavaPairReceiverInputDStream kafkaStream = - KafkaUtils.createStream(streamingContext, - [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); - - You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). - -
    -
    - from pyspark.streaming.kafka import KafkaUtils - - kafkaStream = KafkaUtils.createStream(streamingContext, \ - [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) - - By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/kafka_wordcount.py). -
    -
    - - **Points to remember:** - - - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - - - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. - - - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use -`KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). - -3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. - - For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). - - For Python applications which lack SBT/Maven project management, `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, - - ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... - - Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-0-8-assembly` from the - [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-0-8-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. - -## Approach 2: Direct Approach (No Receivers) -This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. - -This approach has the following advantages over the receiver-based approach (i.e. Approach 1). - -- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. - -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. - -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). - -Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). - -Next, we discuss how to use this approach in your streaming application. - -1. **Linking:** This approach is supported only in Scala/Java application. Link your SBT/Maven project with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). - - groupId = org.apache.spark - artifactId = spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} - -2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. - -
    -
    - import org.apache.spark.streaming.kafka._ - - val directKafkaStream = KafkaUtils.createDirectStream[ - [key class], [value class], [key decoder class], [value decoder class] ]( - streamingContext, [map of Kafka parameters], [set of topics to consume]) - - You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. - See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). -
    -
    - import org.apache.spark.streaming.kafka.*; - - JavaPairInputDStream directKafkaStream = - KafkaUtils.createDirectStream(streamingContext, - [key class], [value class], [key decoder class], [value decoder class], - [map of Kafka parameters], [set of topics to consume]); - - You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. - See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). - -
    -
    - from pyspark.streaming.kafka import KafkaUtils - directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) - - You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type. - By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py). -
    -
    - - In the Kafka parameters, you must specify either `metadata.broker.list` or `bootstrap.servers`. - By default, it will start consuming from the latest offset of each Kafka partition. If you set configuration `auto.offset.reset` in Kafka parameters to `smallest`, then it will start consuming from the smallest offset. - - You can also start consuming from any arbitrary offset using other variations of `KafkaUtils.createDirectStream`. Furthermore, if you want to access the Kafka offsets consumed in each batch, you can do the following. - -
    -
    - // Hold a reference to the current offset ranges, so it can be used downstream - var offsetRanges = Array[OffsetRange]() - - directKafkaStream.transform { rdd => - offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - rdd - }.map { - ... - }.foreachRDD { rdd => - for (o <- offsetRanges) { - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") - } - ... - } -
    -
    - // Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference<>(); - - directKafkaStream.transformToPair( - new Function, JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { - OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - offsetRanges.set(offsets); - return rdd; - } - } - ).map( - ... - ).foreachRDD( - new Function, Void>() { - @Override - public Void call(JavaPairRDD rdd) throws IOException { - for (OffsetRange o : offsetRanges.get()) { - System.out.println( - o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() - ); - } - ... - return null; - } - } - ); -
    -
    - offsetRanges = [] - - def storeOffsetRanges(rdd): - global offsetRanges - offsetRanges = rdd.offsetRanges() - return rdd - - def printOffsetRanges(rdd): - for o in offsetRanges: - print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) - - directKafkaStream\ - .transform(storeOffsetRanges)\ - .foreachRDD(printOffsetRanges) -
    -
    - - You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. - - Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). - - Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. - -3. **Deploying:** This is same as the first approach. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Please read the [Kafka documentation](http://kafka.apache.org/documentation.html) thoroughly before starting an integration using Spark. + +The Kafka project introduced a new consumer api between versions 0.8 and 0.10, so there are 2 separate corresponding Spark Streaming packages available. Please choose the correct package for your brokers and desired features; note that the 0.8 integration is compatible with later 0.9 and 0.10 brokers, but the 0.10 integration is not compatible with earlier brokers. + + +
    Property NameProperty groupspark-submit equivalent
    spark.masterApplication Properties--master
    spark.yarn.keytabApplication Properties--keytab
    spark.yarn.principalApplication Properties--principal
    spark.driver.memory Application Properties
    fetchSizefetchsize The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows).
    spark.sql.broadcastTimeout300 +

    + Timeout in seconds for the broadcast wait time in broadcast joins +

    +
    spark.sql.autoBroadcastJoinThreshold 10485760 (10 MB)
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    spark-streaming-kafka-0-8spark-streaming-kafka-0-10
    Broker Version0.8.2.1 or higher0.10.0 or higher
    Api StabilityStableExperimental
    Language SupportScala, Java, PythonScala, Java
    Receiver DStreamYesNo
    Direct DStreamYesYes
    SSL / TLS SupportNoYes
    Offset Commit ApiNoYes
    Dynamic Topic SubscriptionNoYes
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index e80f1c94ff1b..236ae5d649c4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -126,7 +126,7 @@ ssc.awaitTermination() // Wait for the computation to terminate {% endhighlight %} The complete code can be found in the Spark Streaming example -[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala). +[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala).
    @@ -216,7 +216,7 @@ jssc.awaitTermination(); // Wait for the computation to terminate {% endhighlight %} The complete code can be found in the Spark Streaming example -[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java). +[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
    @@ -277,7 +277,7 @@ ssc.awaitTermination() # Wait for the computation to terminate {% endhighlight %} The complete code can be found in the Spark Streaming example -[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py). +[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/network_wordcount.py).
    @@ -656,7 +656,7 @@ methods for creating DStreams from files as input sources. Python API `fileStream` is not available in the Python API, only `textFileStream` is available. - **Streams based on Custom Receivers:** DStreams can be created with data streams received through custom receivers. See the [Custom Receiver - Guide](streaming-custom-receivers.html) and [DStream Akka](https://github.com/spark-packages/dstream-akka) for more details. + Guide](streaming-custom-receivers.html) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. @@ -683,7 +683,7 @@ and add it to the classpath. Some of these advanced sources are as follows. -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka broker versions 0.8.2.1 or higher. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.6.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. @@ -854,7 +854,7 @@ JavaPairDStream runningCounts = pairs.updateStateByKey(updateFu The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Java code, take a look at the example -[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming /JavaStatefulNetworkWordCount.java).
    @@ -877,7 +877,7 @@ runningCounts = pairs.updateStateByKey(updateFunction) The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Python code, take a look at the example -[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). +[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/stateful_network_wordcount.py). @@ -1428,7 +1428,7 @@ wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala).
    {% highlight java %} @@ -1491,7 +1491,7 @@ wordCounts.foreachRDD(new Function2, Time, Void>() {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java).
    {% highlight python %} @@ -1526,7 +1526,7 @@ wordCounts.foreachRDD(echo) {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/recoverable_network_wordcount.py). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/recoverable_network_wordcount.py).
    @@ -1564,7 +1564,7 @@ words.foreachRDD { rdd => {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala).
    {% highlight java %} @@ -1619,7 +1619,7 @@ words.foreachRDD( ); {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java).
    {% highlight python %} @@ -1661,7 +1661,7 @@ def process(time, rdd): words.foreachRDD(process) {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/sql_network_wordcount.py). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/sql_network_wordcount.py).
    @@ -2350,7 +2350,7 @@ The following table summarizes the semantics under failures: ### With Kafka Direct API {:.no_toc} -In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark {{site.SPARK_VERSION_SHORT}}) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). ## Semantics of output operations {:.no_toc} @@ -2378,61 +2378,12 @@ additional effort may be necessary to achieve exactly-once semantics. There are *************************************************************************************************** *************************************************************************************************** -# Migration Guide from 0.9.1 or below to 1.x -Between Spark 0.9.1 and Spark 1.0, there were a few API changes made to ensure future API stability. -This section elaborates the steps required to migrate your existing code to 1.0. - -**Input DStreams**: All operations that create an input stream (e.g., `StreamingContext.socketStream`, `FlumeUtils.createStream`, etc.) now returns -[InputDStream](api/scala/index.html#org.apache.spark.streaming.dstream.InputDStream) / -[ReceiverInputDStream](api/scala/index.html#org.apache.spark.streaming.dstream.ReceiverInputDStream) -(instead of DStream) for Scala, and [JavaInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaInputDStream.html) / -[JavaPairInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairInputDStream.html) / -[JavaReceiverInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaReceiverInputDStream.html) / -[JavaPairReceiverInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairReceiverInputDStream.html) -(instead of JavaDStream) for Java. This ensures that functionality specific to input streams can -be added to these classes in the future without breaking binary compatibility. -Note that your existing Spark Streaming applications should not require any change -(as these new classes are subclasses of DStream/JavaDStream) but may require recompilation with Spark 1.0. - -**Custom Network Receivers**: Since the release to Spark Streaming, custom network receivers could be defined -in Scala using the class NetworkReceiver. However, the API was limited in terms of error handling -and reporting, and could not be used from Java. Starting Spark 1.0, this class has been -replaced by [Receiver](api/scala/index.html#org.apache.spark.streaming.receiver.Receiver) which has -the following advantages. - -* Methods like `stop` and `restart` have been added to for better control of the lifecycle of a receiver. See -the [custom receiver guide](streaming-custom-receivers.html) for more details. -* Custom receivers can be implemented using both Scala and Java. - -To migrate your existing custom receivers from the earlier NetworkReceiver to the new Receiver, you have -to do the following. - -* Make your custom receiver class extend -[`org.apache.spark.streaming.receiver.Receiver`](api/scala/index.html#org.apache.spark.streaming.receiver.Receiver) -instead of `org.apache.spark.streaming.dstream.NetworkReceiver`. -* Earlier, a BlockGenerator object had to be created by the custom receiver, to which received data was -added for being stored in Spark. It had to be explicitly started and stopped from `onStart()` and `onStop()` -methods. The new Receiver class makes this unnecessary as it adds a set of methods named `store()` -that can be called to store the data in Spark. So, to migrate your custom network receiver, remove any -BlockGenerator object (does not exist any more in Spark 1.0 anyway), and use `store(...)` methods on -received data. - -**Actor-based Receivers**: The Actor-based Receiver APIs have been moved to [DStream Akka](https://github.com/spark-packages/dstream-akka). -Please refer to the project for more details. - -*************************************************************************************************** -*************************************************************************************************** - # Where to Go from Here * Additional guides - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* External DStream data sources: - - [DStream MQTT](https://github.com/spark-packages/dstream-mqtt) - - [DStream Twitter](https://github.com/spark-packages/dstream-twitter) - - [DStream Akka](https://github.com/spark-packages/dstream-akka) - - [DStream ZeroMQ](https://github.com/spark-packages/dstream-zeromq) +* Third-party DStream data sources can be found in [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 8c14c3d220a2..94b7854449a3 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -14,9 +14,9 @@ Structured Streaming is a scalable and fault-tolerant stream processing engine b # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/ -[Java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/ -[Python]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/ +[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/ +[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you [download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
    @@ -88,7 +88,7 @@ val words = lines.as[String].flatMap(_.split(" ")) val wordCounts = words.groupBy("value").count() {% endhighlight %} -This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as[String]`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. +This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as[String]`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
    @@ -117,7 +117,7 @@ Dataset words = lines Dataset wordCounts = words.groupBy("value").count(); {% endhighlight %} -This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as(Encoders.STRING())`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. +This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as(Encoders.STRING())`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
    @@ -142,12 +142,12 @@ words = lines.select( wordCounts = words.groupBy('word').count() {% endhighlight %} -This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as “word”. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. +This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
    -We have now set up the query on the streaming data. All that is left is to actually start receiving data and computing the counts. To do this, we set it up to print the complete set of counts (specified by `outputMode(“complete”)`) to the console every time they are updated. And then start the streaming computation using `start()`. +We have now set up the query on the streaming data. All that is left is to actually start receiving data and computing the counts. To do this, we set it up to print the complete set of counts (specified by `outputMode("complete")`) to the console every time they are updated. And then start the streaming computation using `start()`.
    @@ -361,16 +361,16 @@ table, and Spark runs it as an *incremental* query on the *unbounded* input table. Let’s understand this model in more detail. ## Basic Concepts -Consider the input data stream as the “Input Table”. Every data item that is +Consider the input data stream as the "Input Table". Every data item that is arriving on the stream is like a new row being appended to the Input Table. ![Stream as a Table](img/structured-streaming-stream-as-a-table.png "Stream as a Table") -A query on the input will generate the “Result Table”. Every trigger interval (say, every 1 second), new rows get appended to the Input Table, which eventually updates the Result Table. Whenever the result table gets updated, we would want to write the changed result rows to an external sink. +A query on the input will generate the "Result Table". Every trigger interval (say, every 1 second), new rows get appended to the Input Table, which eventually updates the Result Table. Whenever the result table gets updated, we would want to write the changed result rows to an external sink. ![Model](img/structured-streaming-model.png) -The “Output” is defined as what gets written out to the external storage. The output can be defined in different modes +The "Output" is defined as what gets written out to the external storage. The output can be defined in different modes - *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table. @@ -386,7 +386,7 @@ the final `wordCounts` DataFrame is the result table. Note that the query on streaming `lines` DataFrame to generate `wordCounts` is *exactly the same* as it would be a static DataFrame. However, when this query is started, Spark will continuously check for new data from the socket connection. If there is -new data, Spark will run an “incremental” query that combines the previous +new data, Spark will run an "incremental" query that combines the previous running counts with the new data to compute updated counts, as shown below. ![Model](img/structured-streaming-example-model.png) @@ -618,9 +618,9 @@ The result tables would look something like the following. ![Window Operations](img/structured-streaming-window.png) Since this windowing is similar to grouping, in code, you can use `groupBy()` and `window()` operations to express windowed aggregations. You can see the full code for the below examples in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/ -[Java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/ -[Python]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py). +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/ +[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/ +[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py).
    @@ -682,8 +682,8 @@ Streaming DataFrames can be joined with static DataFrames to create new streamin val staticDf = spark.read. ... val streamingDf = spark.readStream. ... -streamingDf.join(staticDf, “type”) // inner equi-join with a static DF -streamingDf.join(staticDf, “type”, “right_join”) // right outer join with a static DF +streamingDf.join(staticDf, "type") // inner equi-join with a static DF +streamingDf.join(staticDf, "type", "right_join") // right outer join with a static DF {% endhighlight %} @@ -726,9 +726,9 @@ However, note that all of the operations applicable on static DataFrames/Dataset + Full outer join with a streaming Dataset is not supported - + Left outer join with a streaming Dataset on the left is not supported + + Left outer join with a streaming Dataset on the right is not supported - + Right outer join with a streaming Dataset on the right is not supported + + Right outer join with a streaming Dataset on the left is not supported - Any kind of joins between two streaming Datasets are not yet supported. @@ -789,7 +789,7 @@ Here is a table of all the sinks, and the corresponding settings. File Sink
    (only parquet in Spark 2.0) Append -
    writeStream
    .format(“parquet”)
    .start()
    +
    writeStream
    .format("parquet")
    .start()
    Yes Supports writes to partitioned tables. Partitioning by time may be useful. @@ -803,14 +803,14 @@ Here is a table of all the sinks, and the corresponding settings. Console Sink Append, Complete -
    writeStream
    .format(“console”)
    .start()
    +
    writeStream
    .format("console")
    .start()
    No Memory Sink Append, Complete -
    writeStream
    .format(“memory”)
    .queryName(“table”)
    .start()
    +
    writeStream
    .format("memory")
    .queryName("table")
    .start()
    No Saves the output data as a table, for interactive querying. Table name is the query name. @@ -839,7 +839,7 @@ noAggDF .start() // ========== DF with aggregation ========== -val aggDF = df.groupBy(“device”).count() +val aggDF = df.groupBy("device").count() // Print updated aggregations to console aggDF @@ -879,7 +879,7 @@ noAggDF .start(); // ========== DF with aggregation ========== -Dataset aggDF = df.groupBy(“device”).count(); +Dataset aggDF = df.groupBy("device").count(); // Print updated aggregations to console aggDF @@ -919,7 +919,7 @@ noAggDF\ .start() # ========== DF with aggregation ========== -aggDF = df.groupBy(“device”).count() +aggDF = df.groupBy("device").count() # Print updated aggregations to console aggDF\ @@ -1093,11 +1093,11 @@ In case of a failure or intentional shutdown, you can recover the previous progr {% highlight scala %} aggDF - .writeStream - .outputMode("complete") - .option(“checkpointLocation”, “path/to/HDFS/dir”) - .format("memory") - .start() + .writeStream + .outputMode("complete") + .option("checkpointLocation", "path/to/HDFS/dir") + .format("memory") + .start() {% endhighlight %}
    @@ -1105,11 +1105,11 @@ aggDF {% highlight java %} aggDF - .writeStream() - .outputMode("complete") - .option(“checkpointLocation”, “path/to/HDFS/dir”) - .format("memory") - .start(); + .writeStream() + .outputMode("complete") + .option("checkpointLocation", "path/to/HDFS/dir") + .format("memory") + .start(); {% endhighlight %}
    @@ -1117,11 +1117,11 @@ aggDF {% highlight python %} aggDF\ - .writeStream()\ - .outputMode("complete")\ - .option(“checkpointLocation”, “path/to/HDFS/dir”)\ - .format("memory")\ - .start() + .writeStream()\ + .outputMode("complete")\ + .option("checkpointLocation", "path/to/HDFS/dir")\ + .format("memory")\ + .start() {% endhighlight %}
    diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 100ff0b147ef..6fe304999587 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,7 +58,8 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Currently only YARN supports cluster mode for Python applications. +the drivers and the executors. Currently, standalone mode does not support cluster mode for Python +applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. diff --git a/docs/tuning.md b/docs/tuning.md index 1ed14091c054..cbf37213aa72 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -115,28 +115,15 @@ Although there are two relevant configurations, the typical user should not need as the default values are applicable to most workloads: * `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB) -(default 0.6). The rest of the space (25%) is reserved for user data structures, internal +(default 0.6). The rest of the space (40%) is reserved for user data structures, internal metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually large records. * `spark.memory.storageFraction` expresses the size of `R` as a fraction of `M` (default 0.5). `R` is the storage space within `M` where cached blocks immune to being evicted by execution. The value of `spark.memory.fraction` should be set in order to fit this amount of heap space -comfortably within the JVM's old or "tenured" generation. Otherwise, when much of this space is -used for caching and execution, the tenured generation will be full, which causes the JVM to -significantly increase time spent in garbage collection. See -Java GC sizing documentation -for more information. - -The tenured generation size is controlled by the JVM's `NewRatio` parameter, which defaults to 2, -meaning that the tenured generation is 2 times the size of the new generation (the rest of the heap). -So, by default, the tenured generation occupies 2/3 or about 0.66 of the heap. A value of -0.6 for `spark.memory.fraction` keeps storage and execution memory within the old generation with -room to spare. If `spark.memory.fraction` is increased to, say, 0.8, then `NewRatio` may have to -increase to 6 or more. - -`NewRatio` is set as a JVM flag for executors, which means adding -`spark.executor.extraJavaOptions=-XX:NewRatio=x` to a Spark job's configuration. +comfortably within the JVM's old or "tenured" generation. See the discussion of advanced GC +tuning below for details. ## Determining Memory Consumption @@ -217,14 +204,22 @@ temporary objects created during task execution. Some steps which may be useful * Check if there are too many garbage collections by collecting GC stats. If a full GC is invoked multiple times for before a task completes, it means that there isn't enough memory available for executing tasks. -* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of - memory used for caching by lowering `spark.memory.storageFraction`; it is better to cache fewer - objects than to slow down task execution! - * If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You can set the size of the Eden to be an over-estimate of how much memory each task will need. If the size of Eden is determined to be `E`, then you can set the size of the Young generation using the option `-Xmn=4/3*E`. (The scaling up by 4/3 is to account for space used by survivor regions as well.) + +* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of + memory used for caching by lowering `spark.memory.fraction`; it is better to cache fewer + objects than to slow down task execution. Alternatively, consider decreasing the size of + the Young generation. This means lowering `-Xmn` if you've set it as above. If not, try changing the + value of the JVM's `NewRatio` parameter. Many JVMs default this to 2, meaning that the Old generation + occupies 2/3 of the heap. It should be large enough such that this fraction exceeds `spark.memory.fraction`. + +* Try the G1GC garbage collector with `-XX:+UseG1GC`. It can improve performance in some situations where + garbage collection is a bottleneck. Note that with large executor heap sizes, it may be important to + increase the [G1 region size](https://blogs.oracle.com/g1gc/entry/g1_gc_tuning_a_case) + with `-XX:G1HeapRegionSize` * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the @@ -237,6 +232,9 @@ Our experience suggests that the effect of GC tuning depends on your application There are [many more tuning options](http://www.oracle.com/technetwork/java/javase/gc-tuning-6-140523.html) described online, but at a high level, managing how frequently full GC takes place can help in reducing the overhead. +GC tuning flags for executors can be specified by setting `spark.executor.extraJavaOptions` in +a job's configuration. + # Other Considerations ## Level of Parallelism diff --git a/examples/build.gradle b/examples/build.gradle new file mode 100644 index 000000000000..eeeee87812fe --- /dev/null +++ b/examples/build.gradle @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Examples' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-streaming_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-mllib_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-hive_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-graphx_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-streaming-flume_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-streaming-kafka-0.8_' + scalaBinaryVersion) + + compile group: 'org.apache.commons', name: 'commons-math3', version: '3.4.1' + compile group: 'com.github.scopt', name: 'scopt_' + scalaBinaryVersion, version: '3.3.0' + compile group: 'com.twitter', name: 'parquet-hadoop-bundle', version: hiveParquetVersion + + runtimeJar group: 'com.github.scopt', name: 'scopt_' + scalaBinaryVersion, version: '3.3.0' +} + +jar.doLast { + copy { + from configurations.runtimeJar + from outputs + exclude 'scala-*' + into "${buildDir}/jars" + } +} diff --git a/examples/pom.xml b/examples/pom.xml index d2227944d92d..89e0c61a3d69 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java index b0115756cf45..3f034588c952 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -23,12 +23,16 @@ import org.apache.spark.ml.regression.AFTSurvivalRegression; import org.apache.spark.ml.regression.AFTSurvivalRegressionModel; -import org.apache.spark.mllib.linalg.*; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // $example off$ /** diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java index 5f964aca9209..a954dbd20c12 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -47,7 +47,7 @@ public static void main(String[] args) { RowFactory.create(2, 0.2) ); StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); Dataset continuousDataFrame = spark.createDataFrame(data, schema); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java index f8f2fb14be1f..fcf90d8d1874 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java @@ -25,8 +25,8 @@ import java.util.List; import org.apache.spark.ml.feature.ChiSqSelector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.DataTypes; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java index eee92c77a8c5..66ce23b49d36 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -25,8 +25,8 @@ import java.util.List; import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.Metadata; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java index 889f5785dfd8..9e07a0c2f899 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java @@ -19,16 +19,20 @@ // $example on$ import java.util.Arrays; -// $example off$ +import java.util.List; -// $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // $example off$ import org.apache.spark.sql.SparkSession; @@ -44,15 +48,17 @@ public static void main(String[] args) { // $example on$ // Prepare training data. - // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into - // DataFrames, where it uses the bean metadata to infer the schema. - Dataset training = spark.createDataFrame( - Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) - ), LabeledPoint.class); + List dataTraining = Arrays.asList( + RowFactory.create(1.0, Vectors.dense(0.0, 1.1, 0.1)), + RowFactory.create(0.0, Vectors.dense(2.0, 1.0, -1.0)), + RowFactory.create(0.0, Vectors.dense(2.0, 1.3, 1.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 1.2, -0.5)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + Dataset training = spark.createDataFrame(dataTraining, schema); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -87,11 +93,12 @@ public static void main(String[] args) { System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. - Dataset test = spark.createDataFrame(Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) - ), LabeledPoint.class); + List dataTest = Arrays.asList( + RowFactory.create(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + RowFactory.create(0.0, Vectors.dense(3.0, 2.0, -0.1)), + RowFactory.create(1.0, Vectors.dense(0.0, 2.2, -1.5)) + ); + Dataset test = spark.createDataFrame(dataTest, schema); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java index dcd209e28e2b..a561b6d39ba8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -21,7 +21,7 @@ import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.regression.LinearRegressionModel; import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java index 5d29e5454921..a15e5f84a187 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -53,7 +53,7 @@ public static void main(String[] args) { ); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("category", DataTypes.StringType, false, Metadata.empty()) }); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java index ffa979ee013a..d597a9a2ed0b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -25,8 +25,8 @@ import org.apache.spark.ml.feature.PCA; import org.apache.spark.ml.feature.PCAModel; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java index 7afcd0e50cd9..67180df65c72 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -24,8 +24,8 @@ import java.util.List; import org.apache.spark.ml.feature.PolynomialExpansion; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java index 6e0753959efd..800e42c949cb 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java @@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.IDF; import org.apache.spark.ml.feature.IDFModel; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -45,9 +45,9 @@ public static void main(String[] args) { // $example on$ List data = Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") ); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java index 41f1d8750ac4..9bb0f93d3a6a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -23,8 +23,8 @@ import java.util.Arrays; import org.apache.spark.ml.feature.VectorAssembler; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java index 24959c0e10f2..19b8bc83be6e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.ml.feature.VectorSlicer; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index ec02c8bbb8ef..52e3b62b79dd 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -25,7 +25,6 @@ // $example on:basic_parquet_example$ import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Encoders; -// import org.apache.spark.sql.Encoders; // $example on:schema_merging$ // $example on:json_dataset$ import org.apache.spark.sql.Dataset; @@ -92,7 +91,7 @@ public void setCube(int cube) { public static void main(String[] args) { SparkSession spark = SparkSession .builder() - .appName("Java Spark SQL Data Sources Example") + .appName("Java Spark SQL data sources example") .config("spark.some.config.option", "some-value") .getOrCreate(); @@ -100,6 +99,7 @@ public static void main(String[] args) { runBasicParquetExample(spark); runParquetSchemaMergingExample(spark); runJsonDatasetExample(spark); + runJdbcDatasetExample(spark); spark.stop(); } @@ -183,10 +183,10 @@ private static void runParquetSchemaMergingExample(SparkSession spark) { // The final schema consists of all 3 columns in the Parquet files together // with the partitioning column appeared in the partition directory paths // root - // |-- value: int (nullable = true) - // |-- square: int (nullable = true) - // |-- cube: int (nullable = true) - // |-- key : int (nullable = true) + // |-- value: int (nullable = true) + // |-- square: int (nullable = true) + // |-- cube: int (nullable = true) + // |-- key: int (nullable = true) // $example off:schema_merging$ } @@ -216,4 +216,15 @@ private static void runJsonDatasetExample(SparkSession spark) { // $example off:json_dataset$ } + private static void runJdbcDatasetExample(SparkSession spark) { + // $example on:jdbc_dataset$ + Dataset jdbcDF = spark.read() + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .load(); + // $example off:jdbc_dataset$ + } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java index afc18078d471..cff9032f52b5 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java @@ -88,7 +88,7 @@ public static void main(String[] args) { // $example on:init_session$ SparkSession spark = SparkSession .builder() - .appName("Java Spark SQL Example") + .appName("Java Spark SQL basic example") .config("spark.some.config.option", "some-value") .getOrCreate(); // $example off:init_session$ diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index 74f5009581e4..fdc017aed97c 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -182,7 +182,7 @@ def programmatic_schema_example(spark): # $example on:init_session$ spark = SparkSession \ .builder \ - .appName("PythonSQL") \ + .appName("Python Spark SQL basic example") \ .config("spark.some.config.option", "some-value") \ .getOrCreate() # $example off:init_session$ diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index 0bdc3d66ff98..b36c901d2b40 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -92,14 +92,14 @@ def parquet_schema_merging_example(spark): # The final schema consists of all 3 columns in the Parquet files together # with the partitioning column appeared in the partition directory paths. # root - # |-- double: long (nullable = true) - # |-- single: long (nullable = true) - # |-- triple: long (nullable = true) - # |-- key: integer (nullable = true) + # |-- double: long (nullable = true) + # |-- single: long (nullable = true) + # |-- triple: long (nullable = true) + # |-- key: integer (nullable = true) # $example off:schema_merging$ -def json_dataset_examplg(spark): +def json_dataset_example(spark): # $example on:json_dataset$ # spark is from the previous example. sc = spark.sparkContext @@ -112,8 +112,8 @@ def json_dataset_examplg(spark): # The inferred schema can be visualized using the printSchema() method peopleDF.printSchema() # root - # |-- age: long (nullable = true) - # |-- name: string (nullable = true) + # |-- age: long (nullable = true) + # |-- name: string (nullable = true) # Creates a temporary view using the DataFrame peopleDF.createOrReplaceTempView("people") @@ -140,15 +140,29 @@ def json_dataset_examplg(spark): # +---------------+----+ # $example off:json_dataset$ + +def jdbc_dataset_example(spark): + # $example on:jdbc_dataset$ + jdbcDF = spark.read \ + .format("jdbc") \ + .option("url", "jdbc:postgresql:dbserver") \ + .option("dbtable", "schema.tablename") \ + .option("user", "username") \ + .option("password", "password") \ + .load() + # $example off:jdbc_dataset$ + + if __name__ == "__main__": spark = SparkSession \ .builder \ - .appName("PythonSQL") \ + .appName("Python Spark SQL data source example") \ .getOrCreate() basic_datasource_example(spark) parquet_example(spark) parquet_schema_merging_example(spark) - json_dataset_examplg(spark) + json_dataset_example(spark) + jdbc_dataset_example(spark) spark.stop() diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index d9ce5cef1f2b..9b2a2c4e6a16 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -38,7 +38,7 @@ spark = SparkSession \ .builder \ - .appName("PythonSQL") \ + .appName("Python Spark SQL Hive integration example") \ .config("spark.sql.warehouse.dir", warehouse_location) \ .enableHiveSupport() \ .getOrCreate() diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 33e88e15fd47..de489e1bda2c 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -18,31 +18,43 @@ library(SparkR) # $example on:init_session$ -sparkR.session(appName = "MyApp", sparkConfig = list(spark.executor.memory = "1g")) +sparkR.session(appName = "MyApp", sparkConfig = list(spark.some.config.option = "some-value")) # $example off:init_session$ -# $example on:create_DataFrames$ +# $example on:create_df$ df <- read.json("examples/src/main/resources/people.json") # Displays the content of the DataFrame head(df) +## age name +## 1 NA Michael +## 2 30 Andy +## 3 19 Justin # Another method to print the first few rows and optionally truncate the printing of long values showDF(df) -# $example off:create_DataFrames$ +## +----+-------+ +## | age| name| +## +----+-------+ +## |null|Michael| +## | 30| Andy| +## | 19| Justin| +## +----+-------+ +## $example off:create_df$ -# $example on:dataframe_operations$ +# $example on:untyped_ops$ # Create the DataFrame df <- read.json("examples/src/main/resources/people.json") # Show the content of the DataFrame head(df) -## age name -## null Michael -## 30 Andy -## 19 Justin +## age name +## 1 NA Michael +## 2 30 Andy +## 3 19 Justin + # Print the schema in a tree format printSchema(df) @@ -52,58 +64,58 @@ printSchema(df) # Select only the "name" column head(select(df, "name")) -## name -## Michael -## Andy -## Justin +## name +## 1 Michael +## 2 Andy +## 3 Justin # Select everybody, but increment the age by 1 head(select(df, df$name, df$age + 1)) -## name (age + 1) -## Michael null -## Andy 31 -## Justin 20 +## name (age + 1.0) +## 1 Michael NA +## 2 Andy 31 +## 3 Justin 20 # Select people older than 21 head(where(df, df$age > 21)) -## age name -## 30 Andy +## age name +## 1 30 Andy # Count people by age head(count(groupBy(df, "age"))) -## age count -## null 1 -## 19 1 -## 30 1 -# $example off:dataframe_operations$ +## age count +## 1 19 1 +## 2 NA 1 +## 3 30 1 +# $example off:untyped_ops$ # Register this DataFrame as a table. createOrReplaceTempView(df, "table") -# $example on:sql_query$ +# $example on:run_sql$ df <- sql("SELECT * FROM table") -# $example off:sql_query$ +# $example off:run_sql$ -# $example on:source_parquet$ +# $example on:generic_load_save_functions$ df <- read.df("examples/src/main/resources/users.parquet") write.df(select(df, "name", "favorite_color"), "namesAndFavColors.parquet") -# $example off:source_parquet$ +# $example off:generic_load_save_functions$ -# $example on:source_json$ +# $example on:manual_load_options$ df <- read.df("examples/src/main/resources/people.json", "json") namesAndAges <- select(df, "name", "age") write.df(namesAndAges, "namesAndAges.parquet", "parquet") -# $example off:source_json$ +# $example off:manual_load_options$ -# $example on:direct_query$ +# $example on:direct_sql$ df <- sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") -# $example off:direct_query$ +# $example off:direct_sql$ -# $example on:load_programmatically$ +# $example on:basic_parquet_example$ df <- read.df("examples/src/main/resources/people.json", "json") # SparkDataFrame can be saved as Parquet files, maintaining the schema information. @@ -117,7 +129,7 @@ parquetFile <- read.parquet("people.parquet") createOrReplaceTempView(parquetFile, "parquetFile") teenagers <- sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") head(teenagers) -## name +## name ## 1 Justin # We can also run custom R-UDFs on Spark DataFrames. Here we prefix all the names with "Name:" @@ -129,7 +141,7 @@ for (teenName in collect(teenNames)$name) { ## Name: Michael ## Name: Andy ## Name: Justin -# $example off:load_programmatically$ +# $example off:basic_parquet_example$ # $example on:schema_merging$ @@ -146,18 +158,17 @@ write.df(df2, "data/test_table/key=2", "parquet", "overwrite") # Read the partitioned table df3 <- read.df("data/test_table", "parquet", mergeSchema = "true") printSchema(df3) - # The final schema consists of all 3 columns in the Parquet files together -# with the partitioning column appeared in the partition directory paths. -# root -# |-- single: double (nullable = true) -# |-- double: double (nullable = true) -# |-- triple: double (nullable = true) -# |-- key : int (nullable = true) +# with the partitioning column appeared in the partition directory paths +## root +## |-- single: double (nullable = true) +## |-- double: double (nullable = true) +## |-- triple: double (nullable = true) +## |-- key: integer (nullable = true) # $example off:schema_merging$ -# $example on:load_json_file$ +# $example on:json_dataset$ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path <- "examples/src/main/resources/people.json" @@ -166,9 +177,9 @@ people <- read.json(path) # The inferred schema can be visualized using the printSchema() method. printSchema(people) -# root -# |-- age: long (nullable = true) -# |-- name: string (nullable = true) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) # Register this DataFrame as a table. createOrReplaceTempView(people, "people") @@ -176,12 +187,12 @@ createOrReplaceTempView(people, "people") # SQL statements can be run by using the sql methods. teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) -## name +## name ## 1 Justin -# $example off:load_json_file$ +# $example off:json_dataset$ -# $example on:hive_table$ +# $example on:spark_hive$ # enableHiveSupport defaults to TRUE sparkR.session(enableHiveSupport = TRUE) sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -189,12 +200,12 @@ sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src # Queries can be expressed in HiveQL. results <- collect(sql("FROM src SELECT key, value")) -# $example off:hive_table$ +# $example off:spark_hive$ -# $example on:jdbc$ +# $example on:jdbc_dataset$ df <- read.jdbc("jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password") -# $example off:jdbc$ +# $example off:jdbc_dataset$ # Stop the SparkSession now sparkR.session.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 3fbf8e03339e..ef67841f0cbe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -24,8 +24,9 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.ml.linalg.{Vector => MLVector} import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} @@ -225,7 +226,7 @@ object LDAExample { val documents = model.transform(df) .select("features") .rdd - .map { case Row(features: Vector) => features } + .map { case Row(features: MLVector) => Vectors.fromML(features) } .zipWithIndex() .map(_.swap) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 0caba12af0bd..dc3915a4882b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -25,7 +25,7 @@ object SQLDataSourceExample { def main(args: Array[String]) { val spark = SparkSession .builder() - .appName("Spark SQL Data Soures Example") + .appName("Spark SQL data sources example") .config("spark.some.config.option", "some-value") .getOrCreate() @@ -33,6 +33,7 @@ object SQLDataSourceExample { runBasicParquetExample(spark) runParquetSchemaMergingExample(spark) runJsonDatasetExample(spark) + runJdbcDatasetExample(spark) spark.stop() } @@ -99,10 +100,10 @@ object SQLDataSourceExample { // The final schema consists of all 3 columns in the Parquet files together // with the partitioning column appeared in the partition directory paths // root - // |-- value: int (nullable = true) - // |-- square: int (nullable = true) - // |-- cube: int (nullable = true) - // |-- key : int (nullable = true) + // |-- value: int (nullable = true) + // |-- square: int (nullable = true) + // |-- cube: int (nullable = true) + // |-- key: int (nullable = true) // $example off:schema_merging$ } @@ -145,4 +146,15 @@ object SQLDataSourceExample { // $example off:json_dataset$ } + private def runJdbcDatasetExample(spark: SparkSession): Unit = { + // $example on:jdbc_dataset$ + val jdbcDF = spark.read + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .load() + // $example off:jdbc_dataset$ + } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala index 952c074d0345..129b81d5fbbf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala @@ -42,7 +42,7 @@ object SparkSQLExample { // $example on:init_session$ val spark = SparkSession .builder() - .appName("Spark SQL Example") + .appName("Spark SQL basic example") .config("spark.some.config.option", "some-value") .getOrCreate() @@ -203,7 +203,7 @@ object SparkSQLExample { // No pre-defined encoders for Dataset[Map[K,V]], define explicitly implicit val mapEncoder = org.apache.spark.sql.Encoders.kryo[Map[String, Any]] // Primitive types and case classes can be also defined as - implicit val stringIntMapEncoder: Encoder[Map[String, Int]] = ExpressionEncoder() + // implicit val stringIntMapEncoder: Encoder[Map[String, Any]] = ExpressionEncoder() // row.getValuesMap[T] retrieves multiple columns at once into a Map[String, T] teenagersDF.map(teenager => teenager.getValuesMap[Any](List("name", "age"))).collect() diff --git a/external/docker-integration-tests/build.gradle b/external/docker-integration-tests/build.gradle new file mode 100644 index 000000000000..93ae7e08befa --- /dev/null +++ b/external/docker-integration-tests/build.gradle @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Docker Integration Tests' + +dependencies { + compile group: 'com.ibm.db2.jcc', name: 'db2jcc4', version: '10.5.0.5' + + testCompile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + testCompile project(subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion) + testCompile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + testCompile(group: 'com.spotify', name: 'docker-client', version: '3.6.6', classifier: 'shaded') { + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'com.fasterxml.jackson.jaxrs', module: 'jackson-jaxrs-json-provider') + exclude(group: 'com.fasterxml.jackson.datatype', module: 'jackson-datatype-guava') + exclude(group: 'com.fasterxml.jackson.core', module: 'jackson-databind') + exclude(group: 'org.glassfish.jersey.core', module: 'jersey-client') + exclude(group: 'org.glassfish.jersey.connectors', module: 'jersey-apache-connector') + exclude(group: 'org.glassfish.jersey.media', module: 'jersey-media-json-jackson') + } + testCompile group: 'org.apache.httpcomponents', name: 'httpclient', version: httpClientVersion + testCompile group: 'org.apache.httpcomponents', name: 'httpcore', version: httpCoreVersion + testCompile group: 'mysql', name: 'mysql-connector-java', version: '5.1.38' + testCompile group: 'org.postgresql', name: 'postgresql', version: '9.4.1207.jre7' + testCompile group: 'com.oracle', name: 'ojdbc6', version: '11.2.0.1.0' + testCompile group: 'com.sun.jersey', name: 'jersey-server', version: sunJerseyVersion + testCompile group: 'com.sun.jersey', name: 'jersey-core', version: sunJerseyVersion + testCompile group: 'com.sun.jersey', name: 'jersey-servlet', version: sunJerseyVersion + testCompile(group: 'com.sun.jersey', name: 'jersey-json', version: sunJerseyVersion) { + exclude(group: 'stax', module: 'stax-api') + } + testCompile group: 'com.google.guava', name: 'guava', version: '18.0' + + testCompile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile project(subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion, configuration: 'testOutput') +} diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 18e14c7981d8..8c6e22155c26 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml @@ -49,38 +49,7 @@ com.spotify docker-client - shaded test - - - - com.fasterxml.jackson.jaxrs - jackson-jaxrs-json-provider - - - com.fasterxml.jackson.datatype - jackson-datatype-guava - - - com.fasterxml.jackson.core - jackson-databind - - - org.glassfish.jersey.core - jersey-client - - - org.glassfish.jersey.connectors - jersey-apache-connector - - - org.glassfish.jersey.media - jersey-media-json-jackson - - org.apache.httpcomponents @@ -152,43 +121,6 @@ test - - - com.sun.jersey - jersey-server - 1.19 - test - - - com.sun.jersey - jersey-core - 1.19 - test - - - com.sun.jersey - jersey-servlet - 1.19 - test - - - com.sun.jersey - jersey-json - 1.19 - test - - - stax - stax-api - - - - - + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-shade-plugin diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index b524001d0471..f96db6588cda 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/external/spark-ganglia-lgpl/build.gradle b/external/spark-ganglia-lgpl/build.gradle new file mode 100644 index 000000000000..39e0a747ce43 --- /dev/null +++ b/external/spark-ganglia-lgpl/build.gradle @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Ganglia Integration' + +dependencies { + compile project(subprojectBase + 'spark-core_' + scalaBinaryVersion) + + compile group: 'io.dropwizard.metrics', name: 'metrics-ganglia', version: metricsVersion +} diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 93ac8b6e664d..40f2e38832fb 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 000000000000..53c56bd3da6f --- /dev/null +++ b/gradle.properties @@ -0,0 +1,5 @@ +org.gradle.daemon = false +#org.gradle.parallel=true + +# added below options to gradlew* scripts +# org.gradle.jvmargs = -Xmx2g -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=512m diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000000..5ccda13e9cb9 Binary files /dev/null and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000000..30455d487cc4 --- /dev/null +++ b/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Sun Jul 31 00:16:02 IST 2016 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-2.14.1-all.zip diff --git a/gradlew b/gradlew new file mode 100755 index 000000000000..a357c0353981 --- /dev/null +++ b/gradlew @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=512m" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/gradlew.bat b/gradlew.bat new file mode 100644 index 000000000000..b5adeb2fde6e --- /dev/null +++ b/gradlew.bat @@ -0,0 +1,90 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS=-Xmx2g -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=512m + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args +if "%@eval[2+2]" == "4" goto 4NT_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* +goto execute + +:4NT_args +@rem Get arguments from the 4NT Shell from JP Software +set CMD_LINE_ARGS=%$ + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/graphx/build.gradle b/graphx/build.gradle new file mode 100644 index 000000000000..64ee2e856d38 --- /dev/null +++ b/graphx/build.gradle @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project GraphX' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'org.apache.xbean', name: 'xbean-asm5-shaded', version: '4.4' + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + compile group: 'com.github.fommil.netlib', name: 'core', version: '1.1.2' + compile group: 'net.sourceforge.f2j', name: 'arpack_combined_all', version: '0.1' + + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') +} diff --git a/graphx/pom.xml b/graphx/pom.xml index 4f8af77792b0..979217e2ba8f 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/launcher/build.gradle b/launcher/build.gradle new file mode 100644 index 000000000000..22a32f5227a2 --- /dev/null +++ b/launcher/build.gradle @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Launcher' + +dependencies { + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + testCompile(group: 'org.apache.hadoop', name: 'hadoop-client', version: hadoopVersion) { + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.codehaus.jackson', module: 'jackson-mapper-asl') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'org.mockito', module: 'mockito-all') + exclude(group: 'org.mortbay.jetty', module: 'servlet-api-2.5') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'junit', module: 'junit') + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + } + testCompile group: 'org.slf4j', name: 'jul-to-slf4j', version: slf4jVersion +} diff --git a/launcher/pom.xml b/launcher/pom.xml index b6591598ee12..0c0dd0c6069d 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 08873f581123..7be2306b5ee6 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.launcher; @@ -162,7 +180,8 @@ public SparkLauncher setPropertiesFile(String path) { public SparkLauncher setConf(String key, String value) { checkNotNull(key, "key"); checkNotNull(value, "value"); - checkArgument(key.startsWith("spark."), "'key' must start with 'spark.'"); + checkArgument(key.startsWith("spark.") || key.startsWith("snappydata."), + "'key' must start with 'spark.' or 'snappydata.'"); builder.conf.put(key, value); return this; } diff --git a/mllib-local/build.gradle b/mllib-local/build.gradle new file mode 100644 index 000000000000..c4183a09ba73 --- /dev/null +++ b/mllib-local/build.gradle @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project ML Local Library' + +dependencies { + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + compile(group: 'org.scalanlp', name: 'breeze_' + scalaBinaryVersion, version: '0.11.2') { + exclude(group: 'junit', module: 'junit') + exclude(group: 'org.apache.commons', module: 'commons-math3') + } + compile group: 'org.apache.commons', name: 'commons-math3', version: '3.4.1' + + testCompile group: 'org.mockito', name: 'mockito-core', version: '1.10.19' +} diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 1d8f7f4d9bbe..5681b36ea9ab 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index 0ea687bbccc5..f1ecc65af110 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -454,10 +454,13 @@ class SparseMatrix @Since("2.0.0") ( require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - // The Or statement is for the case when the matrix is transposed - require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + - "column indices should be the number of columns + 1. Currently, colPointers.length: " + - s"${colPtrs.length}, numCols: $numCols") + if (isTransposed) { + require(colPtrs.length == numRows + 1, + s"Expecting ${numRows + 1} colPtrs when numRows = $numRows but got ${colPtrs.length}") + } else { + require(colPtrs.length == numCols + 1, + s"Expecting ${numCols + 1} colPtrs when numCols = $numCols but got ${colPtrs.length}") + } require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") diff --git a/mllib/build.gradle b/mllib/build.gradle new file mode 100644 index 000000000000..0bcbd130afec --- /dev/null +++ b/mllib/build.gradle @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project ML Library' + +dependencies { + compile project(subprojectBase + 'snappy-spark-mllib-local_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-streaming_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-graphx_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile(group: 'org.scalanlp', name: 'breeze_' + scalaBinaryVersion, version: '0.11.2') { + exclude(group: 'junit', module: 'junit') + exclude(group: 'org.apache.commons', module: 'commons-math3') + } + compile group: 'org.apache.commons', name: 'commons-math3', version: '3.4.1' + compile(group: 'org.jpmml', name: 'pmml-model', version: '1.2.15') { + exclude(group: 'org.jpmml', module: 'pmml-agent') + } + + testCompile project(path: subprojectBase + 'snappy-spark-mllib-local_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile project(path: subprojectBase + 'snappy-spark-streaming_' + scalaBinaryVersion, configuration: 'testOutput') +} + +// TODO: netlib-lgpl profile + +// fix scala+java test ordering +sourceSets.test.scala.srcDir 'src/test/java' +sourceSets.test.java.srcDirs = [] diff --git a/mllib/pom.xml b/mllib/pom.xml index 40fde1bab7ad..80ab2e01a457 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 576584c62797..88909a9fb953 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.optimization._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom /** @@ -810,9 +811,13 @@ private[ml] class FeedForwardTrainer( getWeights } // TODO: deprecate standard optimizer because it needs Vector - val newWeights = optimizer.optimize(dataStacker.stack(data).map { v => + val trainData = dataStacker.stack(data).map { v => (v._1, OldVectors.fromML(v._2)) - }, w) + } + val handlePersistence = trainData.getStorageLevel == StorageLevel.NONE + if (handlePersistence) trainData.persist(StorageLevel.MEMORY_AND_DISK) + val newWeights = optimizer.optimize(trainData, w) + if (handlePersistence) trainData.unpersist() topology.model(newWeights) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 71293017e052..bb192ab5f25a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -84,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") ( val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 91eee0e69d63..cca337487d6d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -292,6 +292,12 @@ class LogisticRegression @Since("1.2.0") ( val numClasses = histogram.length val numFeatures = summarizer.mean.size + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index ab977c8802e3..f939a1c6808e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -101,6 +101,14 @@ class NaiveBayes @Since("1.5.0") ( setDefault(modelType -> OldNaiveBayes.Multinomial) override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + val numClasses = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[OldLabeledPoint] = extractLabeledPoints(dataset).map(OldLabeledPoint.fromML) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 4ab132e5f294..52345b0626c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -100,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") ( val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index afb1080b9b7d..a97bd0fb16fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -99,6 +99,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -222,6 +223,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { + transformSchema(dataset.schema, logging = true) val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 81749055c761..69f060ad7711 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM} import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatrix, - Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT} + Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, udf} @@ -61,9 +61,9 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT) + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT) + SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) } } @@ -95,6 +95,7 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) @@ -317,6 +318,7 @@ class GaussianMixture @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): GaussianMixtureModel = { + transformSchema(dataset.schema, logging = true) val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 9fb7d6a9a21a..6c46be719674 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -120,6 +120,7 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -304,6 +305,7 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { + transformSchema(dataset.schema, logging = true) val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 034f2c3fa2fd..8e233255b4e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -18,6 +18,9 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats +import org.json4s.JsonAST.JObject +import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging @@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} import org.apache.spark.mllib.impl.PeriodicCheckpointer -import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector, - Vectors => OldVectors} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.VersionUtils private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter @@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Values should be >= 0 * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Value should be >= 0 * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM } } +private object LDAParams { + + /** + * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. + * + * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with + * [[Param]] values extracted from metadata. + * @param metadata Loaded model metadata + */ + def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = { + VersionUtils.majorMinorVersion(metadata.sparkVersion) match { + case (1, 6) => + implicit val format = DefaultFormats + metadata.params match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val origParam = + if (paramName == "topicDistribution") "topicDistributionCol" else paramName + val param = model.getParam(origParam) + val value = param.jsonDecode(compact(render(jsonValue))) + model.set(param, value) + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") + } + case _ => // 2.0+ + DefaultParamsReader.getAndSetParams(model, metadata) + } + } +} + /** * :: Experimental :: @@ -414,11 +454,11 @@ sealed abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") - dataset.toDF + dataset.toDF() } } @@ -574,18 +614,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", - "gammaShape") - .head() - val vocabSize = data.getAs[Int](0) - val topicsMatrix = data.getAs[Matrix](1) - val docConcentration = data.getAs[Vector](2) - val topicConcentration = data.getAs[Double](3) - val gammaShape = data.getAs[Double](4) + val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration") + val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix") + val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector, + topicConcentration: Double, gammaShape: Double) = + matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration", + "topicConcentration", "gammaShape").head() val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession) - DefaultParamsReader.getAndSetParams(model, metadata) + LDAParams.getAndSetParams(model, metadata) model } } @@ -731,9 +769,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val modelPath = new Path(path, "oldModel").toString val oldModel = OldDistributedLDAModel.load(sc, modelPath) - val model = new DistributedLDAModel( - metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None) - DefaultParamsReader.getAndSetParams(model, metadata) + val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize, + oldModel, sparkSession, None) + LDAParams.getAndSetParams(model, metadata) model } } @@ -881,7 +919,7 @@ class LDA @Since("1.6.0") ( } @Since("2.0.0") -object LDA extends DefaultParamsReadable[LDA] { +object LDA extends MLReadable[LDA] { /** Get dataset for spark.mllib LDA */ private[clustering] def getOldDataset( @@ -896,6 +934,20 @@ object LDA extends DefaultParamsReadable[LDA] { } } + private class LDAReader extends MLReader[LDA] { + + private val className = classOf[LDA].getName + + override def load(path: String): LDA = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val model = new LDA(metadata.uid) + LDAParams.getAndSetParams(model, metadata) + model + } + } + + override def read: MLReader[LDA] = new LDAReader + @Since("2.0.0") override def load(path: String): LDA = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 7b11f86279b9..96d0bdee9e2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -68,6 +68,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val inputFeatures = $(inputCols).map(c => dataset.schema(c)) val featureEncoders = getFeatureEncoders(inputFeatures) val featureAttrs = getFeatureAttrs(inputFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 9ed8d83324cf..068f11a2a573 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -170,6 +170,7 @@ class MinMaxScalerModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray val minArray = originalMin.toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 72fb35bd79ad..6e872c1f2cad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.feature import scala.collection.mutable +import org.apache.commons.math3.util.CombinatoricsUtils + import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.linalg._ @@ -84,12 +86,12 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str @Since("1.6.0") object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { - private def choose(n: Int, k: Int): Int = { - Range(n, n - k, -1).product / Range(k, 1, -1).product + private def getPolySize(numFeatures: Int, degree: Int): Int = { + val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree) + require(n <= Integer.MAX_VALUE) + n.toInt } - private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) - private def expandDense( values: Array[Double], lastIdx: Int, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 9a636bd8a5e4..e09800877c69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -97,7 +97,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(inputCol)) val inputFields = schema.fields require(inputFields.forall(_.name != $(outputCol)), s"Output column ${$(outputCol)} already exists.") @@ -108,12 +108,18 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.0.0") override def fit(dataset: Dataset[_]): Bucketizer = { + transformSchema(dataset.schema, logging = true) val splits = dataset.stat.approxQuantile($(inputCol), (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity - val bucketizer = new Bucketizer(uid).setSplits(splits) + val distinctSplits = splits.distinct + if (splits.length != distinctSplits.length) { + log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + + s" buckets as a result.") + } + val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index c95dacfce8cf..2ee899bcca56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -112,6 +112,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): RFormulaModel = { + transformSchema(dataset.schema, logging = true) require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 289037640fd4..259be2679ce1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -63,6 +63,7 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val tableName = Identifiable.randomUID(uid) dataset.createOrReplaceTempView(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index fe79e2ec808a..80fe46796f80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -85,6 +85,7 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { + transformSchema(dataset.schema, logging = true) val counts = dataset.select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) @@ -160,7 +161,7 @@ class StringIndexerModel ( "Skip StringIndexerModel.") return dataset.toDF } - validateAndTransformSchema(dataset.schema) + transformSchema(dataset.schema, logging = true) val indexer = udf { label: String => if (labelToIndex.contains(label)) { @@ -305,6 +306,7 @@ class IndexToString private[ml] (@Since("1.5.0") override val uid: String) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if (!isDefined(labels) || $(labels).isEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 142a2ae44c69..ca900536bc7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -51,6 +51,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema lazy val first = dataset.toDF.first() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index c2b434c3d5cb..d53f3df514df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } @@ -221,24 +222,26 @@ class Word2VecModel private[ml] ( } /** - * Find "num" number of words closest in similarity to the given word. - * Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word. + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. Returns a dataframe with the words and the + * cosine similarities between the synonyms and the given word. */ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { - findSynonyms(wordVectors.transform(word), num) + val spark = SparkSession.builder().getOrCreate() + spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words closest to similarity to the given vector representation - * of the word. Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word vector. + * Find "num" number of words whose vector representation most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. Returns a dataframe with the words and the cosine + * similarities between the synonyms and the given word vector. */ @Since("2.0.0") - def findSynonyms(word: Vector, num: Int): DataFrame = { + def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 700a92cc261b..3ebaba636883 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -196,7 +196,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("2.0.0") override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { - validateAndTransformSchema(dataset.schema, fitting = true) + transformSchema(dataset.schema, logging = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -332,7 +332,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema) + transformSchema(dataset.schema, logging = true) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} if (hasQuantilesCol) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 35396446edc1..cd7b4f2a9c56 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -164,7 +164,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("2.0.0") override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { - validateAndTransformSchema(dataset.schema, fitting = true) + transformSchema(dataset.schema, logging = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -234,6 +234,7 @@ class IsotonicRegressionModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 034223e11538..ac95b9272f1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.util.MLUtils @@ -160,8 +161,10 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val points = - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + + val points = linesReader .map(_.toString.trim) .filterNot(line => line.isEmpty || line.startsWith("#")) .map { line => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 4b4ed2291d13..5cbfbff3e4a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) { rdd.rdd.map(model.transform) } + /** + * Finds synonyms of a word; do not include the word itself in results. + * @param word a word + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) + prepareResult(model.findSynonyms(word, num)) } + /** + * Finds words similar to the the vector representation of a word without + * filtering results. + * @param vector a vector + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) + prepareResult(model.findSynonyms(vector, num)) + } + + private def prepareResult(result: Array[(String, Double)]) = { val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + def getVectors: JMap[String, JList[Float]] = { model.getVectors.map { case (k, v) => (k, v.toList.asJava) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index bc75646d532d..761996f44739 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -518,7 +518,7 @@ class Word2VecModel private[spark] ( } /** - * Find synonyms of a word + * Find synonyms of a word; do not include the word itself in results. * @param word a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) @@ -526,17 +526,34 @@ class Word2VecModel private[spark] ( @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector, num) + findSynonyms(vector, num, Some(word)) } /** - * Find synonyms of the vector representation of a word + * Find synonyms of the vector representation of a word, possibly + * including any words in the model vocabulary whose vector respresentation + * is the supplied vector. * @param vector vector representation of a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { + findSynonyms(vector, num, None) + } + + /** + * Find synonyms of the vector representation of a word, rejecting + * words identical to the value of wordOpt, if one is supplied. + * @param vector vector representation of a word + * @param num number of synonyms to find + * @param wordOpt optionally, a word to reject from the results list + * @return array of (word, cosineSimilarity) + */ + private def findSynonyms( + vector: Vector, + num: Int, + wordOpt: Option[String]): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) @@ -563,12 +580,14 @@ class Word2VecModel private[spark] ( ind += 1 } - wordList.zip(cosVec) - .toSeq - .sortBy(-_._2) - .take(num + 1) - .tail - .toArray + val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2) + + val filtered = wordOpt match { + case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) + case None => scored + } + + filtered.take(num).toArray } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index e8f34388cd9f..4c39cf17f427 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -572,10 +572,13 @@ class SparseMatrix @Since("1.3.0") ( require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - // The Or statement is for the case when the matrix is transposed - require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + - "column indices should be the number of columns + 1. Currently, colPointers.length: " + - s"${colPtrs.length}, numCols: $numCols") + if (isTransposed) { + require(colPtrs.length == numRows + 1, + s"Expecting ${numRows + 1} colPtrs when numRows = $numRows but got ${colPtrs.length}") + } else { + require(colPtrs.length == numCols + 1, + s"Expecting ${numCols + 1} colPtrs when numCols = $numCols but got ${colPtrs.length}") + } require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 8e1f9ddb36cb..9ecd321b128f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -116,5 +116,29 @@ class PolynomialExpansionSuite .setDegree(3) testDefaultReadWrite(t) } + + test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") { + val data: Array[(Vector, Int, Int)] = Array( + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367), + (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367), + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375) + ) + + val df = spark.createDataFrame(data) + .toDF("features", "expectedPoly10size", "expectedPoly11size") + + val t = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + for (i <- Seq(10, 11)) { + val transformed = t.setDegree(i) + .transform(df) + .select(s"expectedPoly${i}size", "polyFeatures") + .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } + + assert(transformed.collect.forall(identity)) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b73dbd62328c..18f1e89ee814 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -52,6 +52,25 @@ class QuantileDiscretizerSuite "Bucket sizes are not within expected relative error tolerance.") } + test("Test Bucketizer on duplicated splits") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 12 + val numBuckets = 5 + val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)) + .map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + + val observedNumBuckets = result.select("result").distinct.count + assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets, + "Observed number of buckets are not within expected range.") + } + test("Test transform method on unseen data") { val spark = this.spark import spark.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index c221d4aa558a..b478fea5e74e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -120,12 +120,20 @@ class StringIndexerSuite test("StringIndexerModel can't overwrite output column") { val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + intercept[IllegalArgumentException] { + new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + } + val indexer = new StringIndexer() .setInputCol("input") - .setOutputCol("output") + .setOutputCol("indexedInput") .fit(df) + intercept[IllegalArgumentException] { - indexer.transform(df) + indexer.setOutputCol("output").transform(df) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 14973e79bf34..561493fbafd6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -74,10 +74,10 @@ class VectorAssemblerSuite val assembler = new VectorAssembler() .setInputCols(Array("a", "b", "c")) .setOutputCol("features") - val thrown = intercept[SparkException] { + val thrown = intercept[IllegalArgumentException] { assembler.transform(df) } - assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + assert(thrown.getMessage contains "Data type StringType is not supported") } test("ML attributes") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 16c74f678587..c8f131153895 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e8ed50acf877..d0aa2cdfe0fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -510,18 +510,18 @@ class ALSSuite (1, 1L, 1d, 0, 0L, 0d, 5.0) ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") withClue("fit should fail when ids exceed integer range. ") { - assert(intercept[IllegalArgumentException] { + assert(intercept[SparkException] { als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) - }.getMessage.contains("was out of Integer range")) - assert(intercept[IllegalArgumentException] { + }.getCause.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) - }.getMessage.contains("was out of Integer range")) - assert(intercept[IllegalArgumentException] { + }.getCause.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) - }.getMessage.contains("was out of Integer range")) - assert(intercept[IllegalArgumentException] { + }.getCause.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) - }.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains("was out of Integer range")) } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 22de4c4ac40e..f4fa216b8eba 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils @@ -68,6 +69,21 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(syms(1)._1 == "japan") } + test("findSynonyms doesn't reject similar word vectors when called with a vector") { + val num = 2 + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + val syms = model.findSynonyms(Vectors.dense(Array(0.52, 0.5, 0.5, 0.5)), num) + assert(syms.length == num) + assert(syms(0)._1 == "china") + assert(syms(1)._1 == "taiwan") + } + test("model load / save") { val word2VecMap = Map( diff --git a/pom.xml b/pom.xml index 9f3d7f003584..79255f9a1f82 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 pom Spark Project Parent POM http://spark.apache.org/ @@ -134,7 +134,7 @@ 1.2.1.spark2 1.2.1 - 10.11.1.1 + 10.12.1.1 1.7.0 1.6.0 9.2.16.v20160414 @@ -159,11 +159,9 @@ 3.2.2 2.11.8 2.11 - ${scala.version} - org.scala-lang 1.9.13 2.6.5 - 1.1.2.4 + 1.1.2.6 1.1.2 1.2.0-incubating 1.10 @@ -746,7 +744,6 @@ com.spotify docker-client - shaded 3.6.6 test @@ -1428,6 +1425,10 @@ org.codehaus.groovy groovy-all + + jline + jline + @@ -1832,6 +1833,11 @@ antlr4-runtime ${antlr4.version} + + ${jline.groupid} + jline + ${jline.version} + @@ -2504,7 +2510,7 @@ hadoop-2.7 - 2.7.2 + 2.7.3 0.9.3 3.4.6 2.6.0 @@ -2537,15 +2543,6 @@ ${scala.version} org.scala-lang - - - - ${jline.groupid} - jline - ${jline.version} - - - @@ -2644,6 +2641,8 @@ 2.11.8 2.11 + 2.12.1 + jline diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 2a989dd4f7a1..77397eab81ed 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,15 +88,8 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.6.0" - // This check can be removed post-2.0 - val project = if (previousSparkVersion == "1.6.0" && - projectRef.project == "streaming-kafka-0-8" - ) { - "streaming-kafka" - } else { - projectRef.project - } + val previousSparkVersion = "2.0.0" + val project = projectRef.project val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4bd615628859..423cbd465ee9 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -42,12 +42,15 @@ object MimaExcludes { Seq( excludePackage("org.apache.spark.rpc"), excludePackage("org.spark-project.jetty"), + excludePackage("org.spark_project.jetty"), + excludePackage("org.apache.spark.internal"), excludePackage("org.apache.spark.unused"), excludePackage("org.apache.spark.unsafe"), excludePackage("org.apache.spark.memory"), excludePackage("org.apache.spark.util.collection.unsafe"), excludePackage("org.apache.spark.sql.catalyst"), excludePackage("org.apache.spark.sql.execution"), + excludePackage("org.apache.spark.sql.internal"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), ProblemFilters.exclude[MissingMethodProblem]( @@ -777,6 +780,13 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jdbc"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.parquetFile"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.applySchema") + ) ++ Seq( + // SPARK-17096: Improve exception string reported through the StreamingQueryListener + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this") + ) ++ Seq( + // SPARK-16240: ML persistence backward compatibility for LDA + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$") ) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b1a9f393423b..133d3b390e91 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -352,7 +352,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sketch, mllibLocal, streamingKafka010 + unsafe, tags ).contains(x) } diff --git a/python/docs/Makefile b/python/docs/Makefile index 12e397e4507c..de86e97d862f 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.1-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.3-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/docs/index.rst b/python/docs/index.rst index 306ffdb0e0f1..421c8de86a3c 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -50,4 +50,3 @@ Indices and tables ================== * :ref:`search` - diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 3be9533c126d..09848b880194 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -8,14 +8,12 @@ Module Context :members: :undoc-members: - pyspark.sql.types module ------------------------ .. automodule:: pyspark.sql.types :members: :undoc-members: - pyspark.sql.functions module ---------------------------- .. automodule:: pyspark.sql.functions diff --git a/python/lib/py4j-0.10.1-src.zip b/python/lib/py4j-0.10.1-src.zip deleted file mode 100644 index a54bcae03afb..000000000000 Binary files a/python/lib/py4j-0.10.1-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.10.3-src.zip b/python/lib/py4j-0.10.3-src.zip new file mode 100644 index 000000000000..bc54f33af151 Binary files /dev/null and b/python/lib/py4j-0.10.3-src.zip differ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6e9f24ef1026..2744bb9ec04e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -787,14 +787,6 @@ def addFile(self, path): """ self._jsc.sc().addFile(path) - def clearFiles(self): - """ - Clear the job's list of files added by L{addFile} or L{addPyFile} so - that they do not get downloaded to any new nodes. - """ - # TODO: remove added .py or .zip files from the PYTHONPATH? - self._jsc.sc().clearFiles() - def addPyFile(self, path): """ Add a .py or .zip dependency for all tasks to be executed on this diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 75d9a0e8cac1..4dab83362a0a 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -99,9 +99,9 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte +--------------------+--------------------+ | mean| cov| +--------------------+--------------------+ - |[-0.0550000000000...|0.002025000000000...| - |[0.82499999999999...|0.005625000000000...| - |[-0.87,-0.7200000...|0.001600000000000...| + |[0.82500000140229...|0.005625000000006...| + |[-0.4777098016092...|0.167969502720916...| + |[-0.4472625243352...|0.167304119758233...| +--------------------+--------------------+ ... >>> transformed = model.transform(df).select("features", "prediction") @@ -124,9 +124,9 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte +--------------------+--------------------+ | mean| cov| +--------------------+--------------------+ - |[-0.0550000000000...|0.002025000000000...| - |[0.82499999999999...|0.005625000000000...| - |[-0.87,-0.7200000...|0.001600000000000...| + |[0.82500000140229...|0.005625000000006...| + |[-0.4777098016092...|0.167969502720916...| + |[-0.4472625243352...|0.167304119758233...| +--------------------+--------------------+ ... diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 298314d46caf..e17d13d4d2bd 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -160,6 +160,8 @@ class CrossValidator(Estimator, ValidatorParams): >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) >>> cvModel = cv.fit(dataset) + >>> cvModel.avgMetrics[0] + 0.5 >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... @@ -228,7 +230,7 @@ def _fit(self, dataset): model = est.fit(train, epm[j]) # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric + metrics[j] += metric/nFolds if eva.isLargerBetter(): bestIndex = np.argmax(metrics) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index c8c3c42774f2..29aa61512577 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -416,7 +416,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... 4.5605, 5.2043, 6.2734]) >>> clusterdata_2 = sc.parallelize(data.reshape(5,3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, - ... maxIterations=150, seed=10) + ... maxIterations=150, seed=4) >>> labels = model.predict(clusterdata_2).collect() >>> labels[0]==labels[1] True diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index c8a6e33f4d9a..929531862d18 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -545,8 +545,7 @@ def load(cls, sc, path): @ignore_unicode_prefix class Word2Vec(object): - """ - Word2Vec creates vector representation of words in a text corpus. + """Word2Vec creates vector representation of words in a text corpus. The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary. The vector representation can be used as features in @@ -568,13 +567,19 @@ class Word2Vec(object): >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc) + Querying for synonyms of a word will not return that word: + >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] [u'b', u'c'] + + But querying for synonyms of a vector may return the word whose + representation is that vector: + >>> vec = model.transform("a") >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] - [u'b', u'c'] + [u'a', u'b'] >>> import os, tempfile >>> path = tempfile.mkdtemp() @@ -592,6 +597,7 @@ class Word2Vec(object): ... pass .. versionadded:: 1.2.0 + """ def __init__(self): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 99bf50b5a164..3f3dfd186c10 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -550,7 +550,7 @@ def test_gmm(self): [-6, -7], ]) clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=56) + maxIterations=10, seed=1) labels = clusters.predict(data).collect() self.assertEqual(labels[0], labels[1]) self.assertEqual(labels[2], labels[3]) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 48867a08dbfa..ed6fd4bca4c5 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -140,8 +140,8 @@ def saveAsLibSVMFile(data, dir): >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> MLUtils.saveAsLibSVMFile(sc.parallelize(examples), tempFile.name) @@ -166,8 +166,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 0508235c1c9e..5fb10f86f469 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -754,8 +754,8 @@ def foreachPartition(self, f): Applies a function to each partition of this RDD. >>> def f(iterator): - ... for x in iterator: - ... print(x) + ... for x in iterator: + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index ac5ce87a3f0f..f54892e1ad5a 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -15,6 +15,25 @@ # limitations under the License. # +# +# Changes for SnappyData data platform. +# +# Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. See accompanying +# LICENSE file. +# + """ An interactive shell. @@ -30,6 +49,7 @@ import pyspark from pyspark.context import SparkContext from pyspark.sql import SparkSession, SQLContext +from pyspark.sql.snappy import SnappyContext from pyspark.storagelevel import StorageLevel if os.environ.get("SPARK_EXECUTOR_URI"): @@ -38,6 +58,8 @@ SparkContext._ensure_initialized() try: + sqlContext = SnappyContext(sc) +except py4j.protocol.Py4JError: # Try to access HiveConf, it will raise exception if Hive is not added SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() spark = SparkSession.builder\ diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index cff73ff192e5..22ec416f6c58 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -18,7 +18,7 @@ """ Important classes of Spark SQL and DataFrames: - - :class:`pyspark.sql.SQLContext` + - :class:`pyspark.sql.SparkSession` Main entry point for :class:`DataFrame` and SQL functionality. - :class:`pyspark.sql.DataFrame` A distributed collection of data grouped into named columns. @@ -26,8 +26,6 @@ A column expression in a :class:`DataFrame`. - :class:`pyspark.sql.Row` A row of data in a :class:`DataFrame`. - - :class:`pyspark.sql.HiveContext` - Main entry point for accessing data stored in Apache Hive. - :class:`pyspark.sql.GroupedData` Aggregation methods, returned by :func:`DataFrame.groupBy`. - :class:`pyspark.sql.DataFrameNaFunctions` @@ -45,7 +43,7 @@ from pyspark.sql.types import Row -from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration from pyspark.sql.session import SparkSession from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions @@ -55,7 +53,8 @@ __all__ = [ - 'SparkSession', 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', - 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'SparkSession', 'SQLContext', 'HiveContext', 'UDFRegistration', + 'DataFrame', 'GroupedData', 'Column', 'Row', + 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 4af930a3cd56..3c5030722f30 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -193,7 +193,7 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function - :param returnType: a :class:`DataType` object + :param returnType: a :class:`pyspark.sql.types.DataType` object >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x)) >>> spark.sql("SELECT stringLengthString('test')").collect() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4cfdf799f6f4..8cdf37188e66 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -152,9 +152,9 @@ def udf(self): @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): """ - Create a :class:`DataFrame` with single LongType column named `id`, - containing elements in a range from `start` to `end` (exclusive) with - step value `step`. + Create a :class:`DataFrame` with single :class:`pyspark.sql.types.LongType` column named + ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. :param start: the start value :param end: the end value (exclusive) @@ -184,7 +184,7 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function - :param returnType: a :class:`DataType` object + :param returnType: a :class:`pyspark.sql.types.DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() @@ -209,13 +209,13 @@ def _inferSchema(self, rdd, samplingRatio=None): :param rdd: an RDD of Row or tuple :param samplingRatio: sampling ratio, or no sampling (default) - :return: StructType + :return: :class:`pyspark.sql.types.StructType` """ return self.sparkSession._inferSchema(rdd, samplingRatio) @since(1.3) @ignore_unicode_prefix - def createDataFrame(self, data, schema=None, samplingRatio=None): + def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): """ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. @@ -226,28 +226,38 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): from ``data``, which should be an RDD of :class:`Row`, or :class:`namedtuple`, or :class:`dict`. - When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or - exception will be thrown at runtime. If the given schema is not StructType, it will be - wrapped into a StructType as its only field, and the field name will be "value", each record - will also be wrapped into a tuple, which can be converted to row later. + When ``schema`` is :class:`pyspark.sql.types.DataType` or + :class:`pyspark.sql.types.StringType`, it must match the + real data, or an exception will be thrown at runtime. If the given schema is not + :class:`pyspark.sql.types.StructType`, it will be wrapped into a + :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value", + each record will also be wrapped into a tuple, which can be converted to row later. If schema inference is needed, ``samplingRatio`` is used to determined the ratio of rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. - :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean, - etc.), or :class:`list`, or :class:`pandas.DataFrame`. - :param schema: a :class:`DataType` or a datatype string or a list of column names, default - is None. The data type string format equals to `DataType.simpleString`, except that - top level struct type can omit the `struct<>` and atomic types use `typeName()` as - their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int` - as a short name for IntegerType. + :param data: an RDD of any kind of SQL data representation(e.g. :class:`Row`, + :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or + :class:`pandas.DataFrame`. + :param schema: a :class:`pyspark.sql.types.DataType` or a + :class:`pyspark.sql.types.StringType` or a list of + column names, default is None. The data type string format equals to + :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can + omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use + ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. + We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`. :param samplingRatio: the sample ratio of rows used for inferring + :param verifySchema: verify data types of every row against schema. :return: :class:`DataFrame` .. versionchanged:: 2.0 - The schema parameter can be a DataType or a datatype string after 2.0. If it's not a - StructType, it will be wrapped into a StructType and each record will also be wrapped - into a tuple. + The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a + :class:`pyspark.sql.types.StringType` after 2.0. + If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a + :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple. + + .. versionchanged:: 2.0.1 + Added verifySchema. >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() @@ -296,7 +306,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): ... Py4JJavaError: ... """ - return self.sparkSession.createDataFrame(data, schema, samplingRatio) + return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema) @since(1.3) def registerDataFrameAsTable(self, df, tableName): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6aff93835495..64d7a2075743 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -61,7 +61,7 @@ class DataFrame(object): people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") - people.filter(people.age > 30).join(department, people.deptId == department.id)\ + people.filter(people.age > 30).join(department, people.deptId == department.id) \\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. versionadded:: 1.3 @@ -196,7 +196,7 @@ def writeStream(self): @property @since(1.3) def schema(self): - """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. + """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) @@ -345,10 +345,7 @@ def take(self, num): >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ - with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe( - self._jdf, num) - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + return self.limit(num).collect() @since(1.3) def foreach(self, f): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 92d709ee40e1..4ea83e24bbc9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -142,7 +142,7 @@ def _(): _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + 'polar coordinates (r, theta).', - 'hypot': 'Computes `sqrt(a^2 + b^2)` without intermediate overflow or underflow.', + 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } @@ -958,7 +958,8 @@ def months_between(date1, date2): @since(1.5) def to_date(col): """ - Converts the column of StringType or TimestampType into DateType. + Converts the column of :class:`pyspark.sql.types.StringType` or + :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType`. >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(to_date(df.t).alias('date')).collect() @@ -1074,18 +1075,18 @@ def window(timeColumn, windowDuration, slideDuration=None, startTime=None): [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in the order of months are not supported. - The time column must be of TimestampType. + The time column must be of :class:`pyspark.sql.types.TimestampType`. Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. - If the `slideDuration` is not provided, the windows will be tumbling windows. + If the ``slideDuration`` is not provided, the windows will be tumbling windows. The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start window intervals. For example, in order to have hourly tumbling windows that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. The output column will be a struct called 'window' by default with the nested columns 'start' - and 'end', where 'start' and 'end' will be of `TimestampType`. + and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`. >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) @@ -1367,7 +1368,7 @@ def locate(substr, str, pos=1): could not be found in str. :param substr: a string - :param str: a Column of StringType + :param str: a Column of :class:`pyspark.sql.types.StringType` :param pos: start position (zero based) >>> df = spark.createDataFrame([('abcd',)], ['s',]) @@ -1439,11 +1440,18 @@ def split(str, pattern): @ignore_unicode_prefix @since(1.5) def regexp_extract(str, pattern, idx): - """Extract a specific(idx) group identified by a java regex, from the specified string column. + """Extract a specific group matched by a Java regex, from the specified string column. + If the regex did not match, or the specified group did not match, an empty string is returned. >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() [Row(d=u'100')] + >>> df = spark.createDataFrame([('foo',)], ['str']) + >>> df.select(regexp_extract('str', '(\d+)', 1).alias('d')).collect() + [Row(d=u'')] + >>> df = spark.createDataFrame([('aaaac',)], ['str']) + >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() + [Row(d=u'')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) @@ -1506,8 +1514,9 @@ def bin(col): @ignore_unicode_prefix @since(1.5) def hex(col): - """Computes hex value of the given column, which could be StringType, - BinaryType, IntegerType or LongType. + """Computes hex value of the given column, which could be :class:`pyspark.sql.types.StringType`, + :class:`pyspark.sql.types.BinaryType`, :class:`pyspark.sql.types.IntegerType` or + :class:`pyspark.sql.types.LongType`. >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() [Row(hex(a)=u'414243', hex(b)=u'3')] @@ -1781,6 +1790,9 @@ def udf(f, returnType=StringType()): duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. + :param f: python function + :param returnType: a :class:`pyspark.sql.types.DataType` object + >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) >>> df.select(slen(df.name).alias('slen')).collect() diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f7c354f51330..dc13a818fcbf 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -96,7 +96,7 @@ def schema(self, schema): By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading. - :param schema: a StructType object + :param schema: a :class:`pyspark.sql.types.StructType` object """ if not isinstance(schema, StructType): raise TypeError("schema should be StructType") @@ -125,7 +125,7 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. :param options: all other string options >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, @@ -156,7 +156,7 @@ def load(self, path=None, format=None, schema=None, **options): def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, - mode=None, columnNameOfCorruptRecord=None): + mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects (one object per record) and returns the result as a :class`DataFrame`. @@ -166,7 +166,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -198,6 +198,14 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -213,7 +221,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, - mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord) + mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, + timestampFormat=timestampFormat) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -285,8 +294,8 @@ def text(self, paths): def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, - negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None, - maxMalformedLogPerPartition=None, mode=None): + negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -294,7 +303,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``inferSchema`` option or specify the schema explicitly using ``schema``. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, @@ -318,7 +327,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non being read should be skipped. If None is set, it uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses - the default value, empty string. + the default value, empty string. Since 2.0.1, this ``nullValue`` param + applies to all supported types including the string type. :param nanValue: sets the string representation of a non-number value. If None is set, it uses the default value, ``NaN``. :param positiveInf: sets the string representation of a positive infinity value. If None @@ -327,9 +337,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non is set, it uses the default value, ``Inf``. :param dateFormat: sets the string that indicates a date format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This - applies to both date type and timestamp type. By default, it is None - which means trying to parse times and date by - ``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``. + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given @@ -356,7 +369,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue, nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, - dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, + dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, + maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode) if isinstance(path, basestring): path = [path] @@ -401,8 +415,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar :param numPartitions: the number of partitions :param predicates: a list of expressions suitable for inclusion in WHERE clauses; each one defines one partition of the :class:`DataFrame` - :param properties: a dictionary of JDBC database connection arguments; normally, - at least a "user" and "password" property should be included + :param properties: a dictionary of JDBC database connection arguments. Normally at + least properties "user" and "password" with their corresponding values. + For example { 'user' : 'SYSTEM', 'password' : 'mypassword' } :return: a DataFrame """ if properties is None: @@ -570,7 +585,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): """Saves the content of the :class:`DataFrame` in JSON format at the specified path. :param path: the path in any Hadoop supported file system @@ -583,11 +598,20 @@ def json(self, path, mode=None, compression=None): :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) - self._set_opts(compression=compression) + self._set_opts( + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.json(path) @since(1.4) @@ -633,7 +657,8 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, - header=None, nullValue=None, escapeQuotes=None, quoteAll=None): + header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, + timestampFormat=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -665,12 +690,21 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, - nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll) + nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, + dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.csv(path) @since(1.5) @@ -716,9 +750,9 @@ def jdbc(self, url, table, mode=None, properties=None): * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. - :param properties: JDBC database connection arguments, a list of - arbitrary string tag/value. Normally at least a - "user" and "password" property should be included. + :param properties: a dictionary of JDBC database connection arguments. Normally at + least properties "user" and "password" with their corresponding values. + For example { 'user' : 'SYSTEM', 'password' : 'mypassword' } """ if properties is None: properties = dict() diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 594f9375f767..d25823dfcacd 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -47,7 +47,7 @@ def toDF(self, schema=None, sampleRatio=None): This is a shorthand for ``spark.createDataFrame(rdd, schema, sampleRatio)`` - :param schema: a StructType or list of names of columns + :param schema: a :class:`pyspark.sql.types.StructType` or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring :return: a DataFrame @@ -274,9 +274,9 @@ def udf(self): @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): """ - Create a :class:`DataFrame` with single LongType column named `id`, - containing elements in a range from `start` to `end` (exclusive) with - step value `step`. + Create a :class:`DataFrame` with single :class:`pyspark.sql.types.LongType` column named + ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. :param start: the start value :param end: the end value (exclusive) @@ -307,7 +307,7 @@ def _inferSchemaFromList(self, data): Infer schema from list of Row or tuple. :param data: list of Row or tuple - :return: StructType + :return: :class:`pyspark.sql.types.StructType` """ if not data: raise ValueError("can not infer schema from empty dataset") @@ -326,7 +326,7 @@ def _inferSchema(self, rdd, samplingRatio=None): :param rdd: an RDD of Row or tuple :param samplingRatio: sampling ratio, or no sampling (default) - :return: StructType + :return: :class:`pyspark.sql.types.StructType` """ first = rdd.first() if not first: @@ -384,17 +384,15 @@ def _createFromLocal(self, data, schema): if schema is None or isinstance(schema, (list, tuple)): struct = self._inferSchemaFromList(data) + converter = _create_converter(struct) + data = map(converter, data) if isinstance(schema, (list, tuple)): for i, name in enumerate(schema): struct.fields[i].name = name struct.names[i] = name schema = struct - elif isinstance(schema, StructType): - for row in data: - _verify_type(row, schema) - - else: + elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) # convert python objects to sql data @@ -403,7 +401,7 @@ def _createFromLocal(self, data, schema): @since(2.0) @ignore_unicode_prefix - def createDataFrame(self, data, schema=None, samplingRatio=None): + def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): """ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. @@ -414,28 +412,31 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): from ``data``, which should be an RDD of :class:`Row`, or :class:`namedtuple`, or :class:`dict`. - When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or - exception will be thrown at runtime. If the given schema is not StructType, it will be - wrapped into a StructType as its only field, and the field name will be "value", each record - will also be wrapped into a tuple, which can be converted to row later. + When ``schema`` is :class:`pyspark.sql.types.DataType` or + :class:`pyspark.sql.types.StringType`, it must match the + real data, or an exception will be thrown at runtime. If the given schema is not + :class:`pyspark.sql.types.StructType`, it will be wrapped into a + :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value", + each record will also be wrapped into a tuple, which can be converted to row later. If schema inference is needed, ``samplingRatio`` is used to determined the ratio of rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean, etc.), or :class:`list`, or :class:`pandas.DataFrame`. - :param schema: a :class:`DataType` or a datatype string or a list of column names, default - is None. The data type string format equals to `DataType.simpleString`, except that - top level struct type can omit the `struct<>` and atomic types use `typeName()` as - their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int` - as a short name for IntegerType. + :param schema: a :class:`pyspark.sql.types.DataType` or a + :class:`pyspark.sql.types.StringType` or a list of + column names, default is ``None``. The data type string format equals to + :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can + omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use + ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use + ``int`` as a short name for ``IntegerType``. :param samplingRatio: the sample ratio of rows used for inferring + :param verifySchema: verify data types of every row against schema. :return: :class:`DataFrame` - .. versionchanged:: 2.0 - The schema parameter can be a DataType or a datatype string after 2.0. If it's not a - StructType, it will be wrapped into a StructType and each record will also be wrapped - into a tuple. + .. versionchanged:: 2.0.1 + Added verifySchema. >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() @@ -500,17 +501,18 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] + verify_func = _verify_type if verifySchema else lambda _, t: True if isinstance(schema, StructType): def prepare(obj): - _verify_type(obj, schema) + verify_func(obj, schema) return obj elif isinstance(schema, DataType): - datatype = schema + dataType = schema + schema = StructType().add("value", schema) def prepare(obj): - _verify_type(obj, datatype) - return (obj, ) - schema = StructType().add("value", datatype) + verify_func(obj, dataType) + return obj, else: if isinstance(schema, list): schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] @@ -595,6 +597,7 @@ def stop(self): """Stop the underlying :class:`SparkContext`. """ self._sc.stop() + SparkSession._instantiatedContext = None @since(2.0) def __enter__(self): diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 8bac347e1308..118a02b6786b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -269,7 +269,7 @@ def schema(self, schema): .. note:: Experimental. - :param schema: a StructType object + :param schema: a :class:`pyspark.sql.types.StructType` object >>> s = spark.readStream.schema(sdf_schema) """ @@ -310,12 +310,12 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. :param options: all other string options - >>> json_sdf = spark.readStream.format("json")\ - .schema(sdf_schema)\ - .load(tempfile.mkdtemp()) + >>> json_sdf = spark.readStream.format("json") \\ + ... .schema(sdf_schema) \\ + ... .load(tempfile.mkdtemp()) >>> json_sdf.isStreaming True >>> json_sdf.schema == sdf_schema @@ -338,7 +338,8 @@ def load(self, path=None, format=None, schema=None, **options): def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, - mode=None, columnNameOfCorruptRecord=None): + mode=None, columnNameOfCorruptRecord=None, dateFormat=None, + timestampFormat=None): """ Loads a JSON file stream (one object per line) and returns a :class`DataFrame`. @@ -349,7 +350,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param path: string represents path to the JSON dataset, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. :param primitivesAsString: infers all primitive values as a string type. If None is set, it uses the default value, ``false``. :param prefersDecimal: infers all floating-point values as a decimal type. If the values @@ -381,6 +382,14 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the value specified in ``spark.sql.columnNameOfCorruptRecord``. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -393,7 +402,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, - mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord) + mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, + timestampFormat=timestampFormat) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -450,8 +460,8 @@ def text(self, path): def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, - negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None, - maxMalformedLogPerPartition=None, mode=None): + negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -461,7 +471,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non .. note:: Experimental. :param path: string, or list of strings, for input path(s). - :param schema: an optional :class:`StructType` for the input schema. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. :param sep: sets the single character as a separator for each field and value. If None is set, it uses the default value, ``,``. :param encoding: decodes the CSV files by the given encoding type. If None is set, @@ -485,7 +495,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non being read should be skipped. If None is set, it uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses - the default value, empty string. + the default value, empty string. Since 2.0.1, this ``nullValue`` param + applies to all supported types including the string type. :param nanValue: sets the string representation of a non-number value. If None is set, it uses the default value, ``NaN``. :param positiveInf: sets the string representation of a positive infinity value. If None @@ -494,9 +505,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non is set, it uses the default value, ``Inf``. :param dateFormat: sets the string that indicates a date format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This - applies to both date type and timestamp type. By default, it is None - which means trying to parse times and date by - ``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``. + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given @@ -521,7 +535,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue, nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, - dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, + dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, + maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) @@ -575,7 +590,7 @@ def format(self, source): .. note:: Experimental. - :param source: string, name of the data source, e.g. 'json', 'parquet'. + :param source: string, name of the data source, which for now can be 'parquet'. >>> writer = sdf.writeStream.format('json') """ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a8ca386e1ce3..1ec40cecf438 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -323,6 +323,14 @@ def test_multiple_udfs(self): [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() self.assertEqual(tuple(row), (6, 5)) + def test_udf_in_filter_on_top_of_outer_join(self): + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(a=1)]) + df = left.join(right, on='a', how='left_outer') + df = df.withColumn('b', udf(lambda x: 'x')(df.a)) + self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) + def test_udf_without_arguments(self): self.spark.catalog.registerFunction("foo", lambda: "bar") [row] = self.spark.sql("SELECT foo()").collect() @@ -371,6 +379,14 @@ def test_udf_in_generate(self): row = df.select(explode(f(*df))).groupBy().sum().first() self.assertEqual(row[0], 10) + def test_udf_with_order_by_and_limit(self): + from pyspark.sql.functions import udf + my_copy = udf(lambda x: x, IntegerType()) + df = self.spark.range(10).orderBy("id") + res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) + res.explain(True) + self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) @@ -411,6 +427,22 @@ def test_infer_schema_to_local(self): df3 = self.spark.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_apply_schema_to_dict_and_rows(self): + schema = StructType().add("b", StringType()).add("a", IntegerType()) + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + for verify in [False, True]: + df = self.spark.createDataFrame(input, schema, verifySchema=verify) + df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(10, df3.count()) + input = [Row(a=x, b=str(x)) for x in range(10)] + df4 = self.spark.createDataFrame(input, schema, verifySchema=verify) + self.assertEqual(10, df4.count()) + def test_create_dataframe_schema_mismatch(self): input = [Row(a=1)] rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) @@ -575,6 +607,41 @@ def check_datatype(datatype): _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_simple_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.show() + + def test_nested_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=schema) + df.collect() + + schema = StructType().add("key", LongType()).add("val", + MapType(LongType(), PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], + schema=schema) + df.collect() + + def test_complex_nested_udt_in_df(self): + from pyspark.sql.functions import udf + + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.collect() + + gd = df.groupby("key").agg({"val": "collect_list"}) + gd.collect() + udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) + gd.select(udf(*gd)).collect() + def test_udt_with_none(self): df = self.spark.range(0, 10, 1, 1) @@ -1798,6 +1865,24 @@ def test_collect_functions(self): sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), ["1", "2", "2", "2"]) + def test_limit_and_take(self): + df = self.spark.range(1, 1000, numPartitions=10) + + def assert_runs_only_one_job_stage_and_task(job_group_name, f): + tracker = self.sc.statusTracker() + self.sc.setJobGroup(job_group_name, description="") + f() + jobs = tracker.getJobIdsForGroup(job_group_name) + self.assertEqual(1, len(jobs)) + stages = tracker.getJobInfo(jobs[0]).stageIds + self.assertEqual(1, len(stages)) + self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks) + + # Regression test for SPARK-10731: take should delegate to Scala implementation + assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1)) + # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n) + assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect()) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index eea80684e2df..b765472d6edb 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -582,6 +582,8 @@ def toInternal(self, obj): else: if isinstance(obj, dict): return tuple(obj.get(n) for n in self.names) + elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): + return tuple(obj[n] for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) elif hasattr(obj, "__dict__"): @@ -786,9 +788,10 @@ def _parse_struct_fields_string(s): def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. The data type string format equals - to `DataType.simpleString`, except that top level struct type can omit the `struct<>` and - atomic types use `typeName()` as their format, e.g. use `byte` instead of `tinyint` for - ByteType. We can also use `int` as a short name for IntegerType. + to :class:`DataType.simpleString`, except that top level struct type can omit + the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead + of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name + for :class:`IntegerType`. >>> _parse_datatype_string("int ") IntegerType @@ -1242,7 +1245,7 @@ def _infer_schema_type(obj, dataType): TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), - StructType: (tuple, list), + StructType: (tuple, list, dict), } @@ -1313,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True): assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) if _type is StructType: - if not isinstance(obj, (tuple, list)): - raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) + # check the type and fields later + pass else: - # subclass of them can not be fromInternald in JVM + # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) @@ -1342,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True): _verify_type(v, dataType.valueType, dataType.valueContainsNull) elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType, f.nullable) + if isinstance(obj, dict): + for f in dataType.fields: + _verify_type(obj.get(f.name), f.dataType, f.nullable) + elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): + # the order in obj could be different than dataType.fields + for f in dataType.fields: + _verify_type(obj[f.name], f.dataType, f.nullable) + elif isinstance(obj, (tuple, list)): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType, f.nullable) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + for f in dataType.fields: + _verify_type(d.get(f.name), f.dataType, f.nullable) + else: + raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) # This is used to unpickle a Row from JVM @@ -1409,6 +1426,7 @@ def __new__(self, *args, **kwargs): names = sorted(kwargs.keys()) row = tuple.__new__(self, [kwargs[n] for n in names]) row.__fields__ = names + row.__from_dict__ = True return row else: @@ -1484,7 +1502,7 @@ def __getattr__(self, item): raise AttributeError(item) def __setattr__(self, key, value): - if key != '__fields__': + if key != '__fields__' and key != "__from_dict__": raise Exception("Row is read-only") self.__dict__[key] = value diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 2c1a667fc80c..bf27d8047a75 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -287,6 +287,9 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + return (self._topic, self._partition).__hash__() + class Broker(object): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 360ba1e7167c..5ac007cd598b 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -41,6 +41,9 @@ else: import unittest +if sys.version >= "3": + long = int + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext @@ -1058,7 +1061,6 @@ def test_kafka_direct_stream(self): stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) self._validateStreamResult(sendData, stream) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_from_offset(self): """Test the Python direct Kafka stream API with start offset specified.""" topic = self._randomTopic() @@ -1072,7 +1074,6 @@ def test_kafka_direct_stream_from_offset(self): stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) self._validateStreamResult(sendData, stream) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd(self): """Test the Python direct Kafka RDD API.""" topic = self._randomTopic() @@ -1085,7 +1086,6 @@ def test_kafka_rdd(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self._validateRddResult(sendData, rdd) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_with_leaders(self): """Test the Python direct Kafka RDD API with leaders.""" topic = self._randomTopic() @@ -1100,7 +1100,6 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_get_offsetRanges(self): """Test Python direct Kafka RDD get OffsetRanges.""" topic = self._randomTopic() @@ -1113,7 +1112,6 @@ def test_kafka_rdd_get_offsetRanges(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self.assertEqual(offsetRanges, rdd.offsetRanges()) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_foreach_get_offsetRanges(self): """Test the Python direct Kafka stream foreachRDD get offsetRanges.""" topic = self._randomTopic() @@ -1138,7 +1136,6 @@ def getOffsetRanges(_, rdd): self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_transform_get_offsetRanges(self): """Test the Python direct Kafka stream transform get offsetRanges.""" topic = self._randomTopic() @@ -1176,7 +1173,6 @@ def test_topic_and_partition_equality(self): self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_transform_with_checkpoint(self): """Test the Python direct Kafka stream transform with checkpoint correctly recovered.""" topic = self._randomTopic() @@ -1225,7 +1221,6 @@ def setup(): finally: shutil.rmtree(tmpdir) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_message_handler(self): """Test Python direct Kafka RDD MessageHandler.""" topic = self._randomTopic() @@ -1242,7 +1237,6 @@ def getKeyAndDoubleMessage(m): messageHandler=getKeyAndDoubleMessage) self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_message_handler(self): """Test the Python direct Kafka stream MessageHandler.""" topic = self._randomTopic() diff --git a/repl/build.gradle b/repl/build.gradle new file mode 100644 index 000000000000..018602803e4f --- /dev/null +++ b/repl/build.gradle @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project REPL' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'org.apache.xbean', name: 'xbean-asm5-shaded', version: '4.4' + compile group: 'org.scala-lang', name: 'scala-compiler', version: scalaVersion + compile group: 'org.slf4j', name: 'jul-to-slf4j', version: slf4jVersion + compile group: 'jline', name: 'jline', version: jlineVersion + + runtime project(subprojectBase + 'snappy-spark-mllib_' + scalaBinaryVersion) + + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') +} + +if (scalaBinaryVersion == '2.11') { + sourceSets.main.scala.srcDir 'scala-2.11/src/main/scala' + sourceSets.test.scala.srcDir 'scala-2.11/src/test/scala' +} else { + sourceSets.main.scala.srcDir 'scala-2.10/src/main/scala' + sourceSets.test.scala.srcDir 'scala-2.10/src/test/scala' +} diff --git a/repl/pom.xml b/repl/pom.xml index 0b5ec1a08c82..4b70d647d59e 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml @@ -72,6 +72,10 @@ ${scala.version} + ${jline.groupid} + jline + + org.slf4j jul-to-slf4j @@ -161,13 +165,6 @@ scala-2.10 - - - ${jline.groupid} - jline - ${jline.version} - - diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 16f330a320a4..e017aa42a4c1 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1059,7 +1059,8 @@ class SparkILoop( @deprecated("Use `process` instead", "2.9.0") private def main(settings: Settings): Unit = process(settings) - private[repl] def getAddedJars(): Array[String] = { + @DeveloperApi + def getAddedJars(): Array[String] = { val conf = new SparkConf().setMaster(getMaster()) val envJars = sys.env.get("ADD_JARS") if (envJars.isDefined) { diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c10db947bcb4..f7d7a4f04131 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -396,6 +396,29 @@ class ReplSuite extends SparkFunSuite { assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) } + test("replicating blocks of object with class defined in repl") { + val output = runInterpreter("local-cluster[2,1,1024]", + """ + |val timeout = 60000 // 60 seconds + |val start = System.currentTimeMillis + |while(sc.getExecutorStorageStatus.size != 3 && + | (System.currentTimeMillis - start) < timeout) { + | Thread.sleep(10) + |} + |if (System.currentTimeMillis - start >= timeout) { + | throw new java.util.concurrent.TimeoutException("Executors were not up in 60 seconds") + |} + |import org.apache.spark.storage.StorageLevel._ + |case class Foo(i: Int) + |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2) + |ret.count() + |sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains(": Int = 20", output) + } + test("line wrapper only initialized once when used as encoder outer scope") { val output = runInterpreter("local", """ diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 5f7bf41caf9b..b7284487c511 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -26,5 +26,8 @@ fi export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: -export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.1-src.zip:${PYTHONPATH}" +if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then + export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.3-src.zip:${PYTHONPATH}" + export PYSPARK_PYTHONPATH_SET=1 +fi diff --git a/sbin/start-master.sh b/sbin/start-master.sh index 981cb15bc000..d970fcc45e2c 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -48,7 +48,7 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_HOST" = "" ]; then - SPARK_MASTER_HOST=`hostname` + SPARK_MASTER_HOST=`hostname -f` fi if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index 06a966d1c20b..ef65fb953914 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -34,7 +34,7 @@ if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then fi if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then - SPARK_MESOS_DISPATCHER_HOST=`hostname` + SPARK_MESOS_DISPATCHER_HOST=`hostname -f` fi if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 0fa160548970..7d8871251f81 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -32,7 +32,7 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_HOST" = "" ]; then - SPARK_MASTER_HOST="`hostname`" + SPARK_MASTER_HOST="`hostname -f`" fi # Launch the slaves diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 9a35183c6373..7fe0697202cd 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -250,6 +250,14 @@ This file is divided into 3 sections: Omit braces in case clauses. + + + ^Override$ + override modifier should be used instead of @java.lang.Override. + + + + diff --git a/settings.gradle b/settings.gradle new file mode 100644 index 000000000000..ca33d18d94bf --- /dev/null +++ b/settings.gradle @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +def scalaBinaryVersion = '2.11' +rootProject.name = 'snappy-spark' + +include ':snappy-spark-tags_' + scalaBinaryVersion +include ':snappy-spark-core_' + scalaBinaryVersion +include ':snappy-spark-graphx_' + scalaBinaryVersion +include ':snappy-spark-mllib_' + scalaBinaryVersion +include ':snappy-spark-mllib-local_' + scalaBinaryVersion +include ':snappy-spark-tools_' + scalaBinaryVersion +include ':snappy-spark-network-common_' + scalaBinaryVersion +include ':snappy-spark-network-shuffle_' + scalaBinaryVersion +include ':snappy-spark-network-yarn_' + scalaBinaryVersion +include ':snappy-spark-sketch_' + scalaBinaryVersion +include ':snappy-spark-yarn_' + scalaBinaryVersion +include ':snappy-spark-streaming_' + scalaBinaryVersion +include ':snappy-spark-catalyst_' + scalaBinaryVersion +include ':snappy-spark-sql_' + scalaBinaryVersion +include ':snappy-spark-hive_' + scalaBinaryVersion +include ':snappy-spark-hive-thriftserver_' + scalaBinaryVersion +include ':snappy-spark-unsafe_' + scalaBinaryVersion +include ':snappy-spark-assembly_' + scalaBinaryVersion +include ':snappy-spark-streaming-flume_' + scalaBinaryVersion +include ':snappy-spark-streaming-flume-sink_' + scalaBinaryVersion +include ':snappy-spark-streaming-kafka-0.8_' + scalaBinaryVersion +include ':snappy-spark-streaming-kafka-0.10_' + scalaBinaryVersion +include ':snappy-spark-examples_' + scalaBinaryVersion +include ':snappy-spark-repl_' + scalaBinaryVersion +include ':snappy-spark-launcher_' + scalaBinaryVersion +include ':snappy-spark-assembly_' + scalaBinaryVersion + +project(':snappy-spark-tags_' + scalaBinaryVersion).projectDir = "$rootDir/common/tags" as File +project(':snappy-spark-core_' + scalaBinaryVersion).projectDir = "$rootDir/core" as File +project(':snappy-spark-graphx_' + scalaBinaryVersion).projectDir = "$rootDir/graphx" as File +project(':snappy-spark-mllib_' + scalaBinaryVersion).projectDir = "$rootDir/mllib" as File +project(':snappy-spark-mllib-local_' + scalaBinaryVersion).projectDir = "$rootDir/mllib-local" as File +project(':snappy-spark-tools_' + scalaBinaryVersion).projectDir = "$rootDir/tools" as File +project(':snappy-spark-network-common_' + scalaBinaryVersion).projectDir = "$rootDir/common/network-common" as File +project(':snappy-spark-network-shuffle_' + scalaBinaryVersion).projectDir = "$rootDir/common/network-shuffle" as File +project(':snappy-spark-network-yarn_' + scalaBinaryVersion).projectDir = "$rootDir/common/network-yarn" as File +project(':snappy-spark-sketch_' + scalaBinaryVersion).projectDir = "$rootDir/common/sketch" as File +project(':snappy-spark-yarn_' + scalaBinaryVersion).projectDir = "$rootDir/yarn" as File +project(':snappy-spark-streaming_' + scalaBinaryVersion).projectDir = "$rootDir/streaming" as File +project(':snappy-spark-catalyst_' + scalaBinaryVersion).projectDir = "$rootDir/sql/catalyst" as File +project(':snappy-spark-sql_' + scalaBinaryVersion).projectDir = "$rootDir/sql/core" as File +project(':snappy-spark-hive_' + scalaBinaryVersion).projectDir = "$rootDir/sql/hive" as File +project(':snappy-spark-hive-thriftserver_' + scalaBinaryVersion).projectDir = "$rootDir/sql/hive-thriftserver" as File +project(':snappy-spark-unsafe_' + scalaBinaryVersion).projectDir = "$rootDir/common/unsafe" as File +project(':snappy-spark-assembly_' + scalaBinaryVersion).projectDir = "$rootDir/assembly" as File +project(':snappy-spark-streaming-flume_' + scalaBinaryVersion).projectDir = "$rootDir/external/flume" as File +project(':snappy-spark-streaming-flume-sink_' + scalaBinaryVersion).projectDir = "$rootDir/external/flume-sink" as File +project(':snappy-spark-streaming-kafka-0.8_' + scalaBinaryVersion).projectDir = "$rootDir/external/kafka-0-8" as File +project(':snappy-spark-streaming-kafka-0.10_' + scalaBinaryVersion).projectDir = "$rootDir/external/kafka-0-10" as File +project(':snappy-spark-examples_' + scalaBinaryVersion).projectDir = "$rootDir/examples" as File +project(':snappy-spark-repl_' + scalaBinaryVersion).projectDir = "$rootDir/repl" as File +project(':snappy-spark-launcher_' + scalaBinaryVersion).projectDir = "$rootDir/launcher" as File +project(':snappy-spark-assembly_' + scalaBinaryVersion).projectDir = "$rootDir/assembly" as File + +if (rootProject.hasProperty('docker')) { + include ':snappy-spark-docker-integration-tests_' + scalaBinaryVersion + project(':snappy-spark-docker-integration-tests_' + scalaBinaryVersion).projectDir = "$rootDir/external/docker-integration-tests" as File +} +if (rootProject.hasProperty('ganglia')) { + include ':snappy-spark-ganglia-lgpl_' + scalaBinaryVersion + project(':snappy-spark-ganglia-lgpl_' + scalaBinaryVersion).projectDir = "$rootDir/external/spark-ganglia-lgpl" as File +} diff --git a/sql/catalyst/.gitignore b/sql/catalyst/.gitignore new file mode 100644 index 000000000000..42b6ce41f8a6 --- /dev/null +++ b/sql/catalyst/.gitignore @@ -0,0 +1 @@ +src/generated/antlr4 diff --git a/sql/catalyst/build.gradle b/sql/catalyst/build.gradle new file mode 100644 index 000000000000..cc0e9bbf2822 --- /dev/null +++ b/sql/catalyst/build.gradle @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Catalyst' + +apply plugin: 'antlr' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-unsafe_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'org.scala-lang', name: 'scala-compiler', version: scalaVersion + compile group: 'org.scala-lang.modules', name: 'scala-parser-combinators_' + scalaBinaryVersion, version: '1.0.4' + compile group: 'org.codehaus.janino', name: 'janino', version: '2.7.8' + compile group: 'org.antlr', name: 'antlr4-runtime', version: antlrVersion + compile group: 'commons-codec', name: 'commons-codec', version: commonsCodecVersion + antlr group: 'org.antlr', name: 'antlr4', version: antlrVersion + + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') +} + +compileScala.dependsOn generateGrammarSource + +sourceSets.main.antlr.srcDirs = [ 'src/main/antlr4' ] + +// use an output directory that IDEA can easily find +String antlrOut = 'src/generated/antlr4' +// add generated sources to scala compiler path (plugin adds it to java path) +sourceSets.main.scala.srcDir antlrOut +sourceSets.main.java.srcDirs = [] + +generateGrammarSource { + arguments += [ '-package', 'org.apache.spark.sql.catalyst.parser', '-visitor' ] + outputDirectory = file(antlrOut) +} diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 0bfdb13cec94..efa327cbf21c 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 4c15f9cec657..8b721407eb17 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -16,6 +16,30 @@ grammar SqlBase; +@members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is folllowed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } +} + tokens { DELIMITER } @@ -84,6 +108,7 @@ statement | ALTER VIEW tableIdentifier DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? #dropTable | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable | CREATE (OR REPLACE)? TEMPORARY? VIEW (IF NOT EXISTS)? tableIdentifier @@ -121,6 +146,7 @@ statement | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE tableIdentifier partitionSpec? #loadData | TRUNCATE TABLE tableIdentifier partitionSpec? #truncateTable + | MSCK REPAIR TABLE tableIdentifier #repairTable | op=(ADD | LIST) identifier .*? #manageResource | SET ROLE .*? #failNativeCommand | SET .*? #setConfiguration @@ -154,7 +180,6 @@ unsupportedHiveNativeCommands | kw1=UNLOCK kw2=DATABASE | kw1=CREATE kw2=TEMPORARY kw3=MACRO | kw1=DROP kw2=TEMPORARY kw3=MACRO - | kw1=MSCK kw2=REPAIR kw3=TABLE | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED @@ -366,11 +391,12 @@ setQuantifier ; relation - : left=relation - ((CROSS | joinType) JOIN right=relation joinCriteria? - | NATURAL joinType JOIN right=relation - ) #joinRelation - | relationPrimary #relationDefault + : relationPrimary joinRelation* + ; + +joinRelation + : (CROSS | joinType) JOIN right=relationPrimary joinCriteria? + | NATURAL joinType JOIN right=relationPrimary ; joinType @@ -425,6 +451,7 @@ relationPrimary | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 + | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction ; inlineTable @@ -493,6 +520,7 @@ valueExpression primaryExpression : constant #constantDefault + | name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall | ASTERISK #star | qualifiedName '.' ASTERISK #star | '(' expression (',' expression)+ ')' #rowConstructor @@ -616,13 +644,14 @@ quotedIdentifier ; number - : DECIMAL_VALUE #decimalLiteral - | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral - | INTEGER_VALUE #integerLiteral - | BIGINT_LITERAL #bigIntLiteral - | SMALLINT_LITERAL #smallIntLiteral - | TINYINT_LITERAL #tinyIntLiteral - | DOUBLE_LITERAL #doubleLiteral + : MINUS? DECIMAL_VALUE #decimalLiteral + | MINUS? SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; nonReserved @@ -645,7 +674,7 @@ nonReserved | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT | DBPROPERTIES | DFS | TRUNCATE | COMPUTE | LIST | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE @@ -653,7 +682,7 @@ nonReserved | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN | UNBOUNDED | WHEN - | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT + | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP ; SELECT: 'SELECT'; @@ -858,6 +887,7 @@ LOCK: 'LOCK'; UNLOCK: 'UNLOCK'; MSCK: 'MSCK'; REPAIR: 'REPAIR'; +RECOVER: 'RECOVER'; EXPORT: 'EXPORT'; IMPORT: 'IMPORT'; LOAD: 'LOAD'; @@ -873,6 +903,8 @@ OPTION: 'OPTION'; ANTI: 'ANTI'; LOCAL: 'LOCAL'; INPATH: 'INPATH'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' @@ -900,18 +932,22 @@ INTEGER_VALUE ; DECIMAL_VALUE - : DIGIT+ '.' DIGIT* - | '.' DIGIT+ + : DECIMAL_DIGITS {isValidDecimal()}? ; SCIENTIFIC_DECIMAL_VALUE - : DIGIT+ ('.' DIGIT*)? EXPONENT - | '.' DIGIT+ EXPONENT + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? ; DOUBLE_LITERAL - : - (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? ; IDENTIFIER @@ -922,6 +958,11 @@ BACKQUOTED_IDENTIFIER : '`' ( ~'`' | '``' )* '`' ; +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + fragment EXPONENT : 'E' [+-]? DIGIT+ ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index dd2f39eb816f..9027652d57f1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -31,6 +31,7 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -577,8 +578,12 @@ public boolean equals(Object other) { return (sizeInBytes == o.sizeInBytes) && ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, sizeInBytes); + } else if (!(other instanceof InternalRow)) { + return false; + } else { + throw new IllegalArgumentException( + "Cannot compare UnsafeRow to " + other.getClass().getName()); } - return false; } /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 7dd932d1981b..5a80928bded1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.catalyst.expressions.codegen; @@ -120,7 +138,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { holder.cursor += 8; } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + // assert bytes.length <= 16; holder.grow(bytes.length); // Write the bytes to the variable length portion. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index d83eef7a4162..e16850efbea5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -463,6 +463,6 @@ trait Row extends Serializable { * @throws NullPointerException when value is null. */ private def getAnyValAs[T <: AnyVal](i: Int): T = - if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null") + if (isNullAt(i)) throw new NullPointerException(s"Value at index $i is null") else getAs[T](i) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9cc7b2ac7920..00fd22bd430b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.catalyst @@ -170,6 +188,8 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + + case a: ArrayData => a } } @@ -227,6 +247,8 @@ object CatalystTypeConverters { i += 1 } ArrayBasedMapData(convertedKeys, convertedValues) + + case m: MapData => m } override def toScala(catalystValue: MapData): Map[Any, Any] = { @@ -272,6 +294,8 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + + case row: InternalRow => row } override def toScala(row: InternalRow): Row = { @@ -382,7 +406,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + def createToCatalystConverter(dataType: DataType): Any => Any = { if (isPrimitive(dataType)) { // Although the `else` branch here is capable of handling inbound conversion of primitives, // we add some special-case handling for those types here. The motivation for this relates to @@ -409,7 +433,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + def createToScalaConverter(dataType: DataType): Any => Any = { if (isPrimitive(dataType)) { identity } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8affb033d828..dd36468583b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -720,7 +720,7 @@ object ScalaReflection extends ScalaReflection { /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ - private[sql] def definedByConstructorParams(tpe: Type): Boolean = { + def definedByConstructorParams(tpe: Type): Boolean = { tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2efa997ff22d..3e4c76921726 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,6 +86,7 @@ class Analyzer( WindowsSubstitution, EliminateUnions), Batch("Resolution", fixedPoint, + ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: ResolveDeserializer :: @@ -107,6 +108,7 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: + ResolveInlineTables :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -246,7 +248,7 @@ class Analyzer( }.isDefined } - private[sql] def hasGroupingFunction(e: Expression): Boolean = { + private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.collectFirst { case g: Grouping => g case g: GroupingID => g @@ -547,8 +549,7 @@ class Analyzer( case a: Aggregate if containsStar(a.aggregateExpressions) => if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) { failAnalysis( - "Group by position: star is not allowed to use in the select list " + - "when using ordinals in group by") + "Star (*) is not allowed in select list when GROUP BY ordinal position is used") } else { a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) } @@ -723,9 +724,9 @@ class Analyzer( if (index > 0 && index <= child.output.size) { SortOrder(child.output(index - 1), direction) } else { - throw new UnresolvedException(s, - s"Order/sort By position: $index does not exist " + - s"The Select List is indexed from 1 to ${child.output.size}") + s.failAnalysis( + s"ORDER BY position $index is not in select list " + + s"(valid range is [1, ${child.output.size}])") } case o => o } @@ -737,17 +738,18 @@ class Analyzer( if conf.groupByOrdinal && aggs.forall(_.resolved) && groups.exists(IntegerIndex.unapply(_).nonEmpty) => val newGroups = groups.map { - case IntegerIndex(index) if index > 0 && index <= aggs.size => + case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size => aggs(index - 1) match { case e if ResolveAggregateFunctions.containsAggregate(e) => - throw new UnresolvedException(a, - s"Group by position: the '$index'th column in the select contains an " + - s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY") + ordinal.failAnalysis( + s"GROUP BY position $index is an aggregate function, and " + + "aggregate functions are not allowed in GROUP BY") case o => o } - case IntegerIndex(index) => - throw new UnresolvedException(a, - s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.") + case ordinal @ IntegerIndex(index) => + ordinal.failAnalysis( + s"GROUP BY position $index is not in select list " + + s"(valid range is [1, ${aggs.size}])") case o => o } Aggregate(newGroups, aggs, child) @@ -1412,7 +1414,7 @@ class Analyzer( * Construct the output attributes for a [[Generator]], given a list of names. If the list of * names is empty names are assigned from field names in generator. */ - private[sql] def makeGeneratorOutput( + private[analysis] def makeGeneratorOutput( generator: Generator, names: Seq[String]): Seq[Attribute] = { val elementAttrs = generator.elementSchema.toAttributes @@ -1647,27 +1649,17 @@ class Analyzer( } }.toSeq - // Third, for every Window Spec, we add a Window operator and set currentChild as the - // child of it. - var currentChild = child - var i = 0 - while (i < groupedWindowExpressions.size) { - val ((partitionSpec, orderSpec), windowExpressions) = groupedWindowExpressions(i) - // Set currentChild to the newly created Window operator. - currentChild = - Window( - windowExpressions, - partitionSpec, - orderSpec, - currentChild) - - // Move to next Window Spec. - i += 1 - } + // Third, we aggregate them by adding each Window operator for each Window Spec and then + // setting this to the child of the next Window operator. + val windowOps = + groupedWindowExpressions.foldLeft(child) { + case (last, ((partitionSpec, orderSpec), windowExpressions)) => + Window(windowExpressions, partitionSpec, orderSpec, last) + } - // Finally, we create a Project to output currentChild's output + // Finally, we create a Project to output windowOps's output // newExpressionsWithWindowFunctions. - Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps) } // end of addWindow // We have to use transformDown at here to make sure the rule of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 8b87a4e41c23..790566c7659c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -342,6 +342,7 @@ trait CheckAnalysis extends PredicateHelper { case InsertIntoTable(t, _, _, _, _) if !t.isInstanceOf[LeafNode] || + t.isInstanceOf[Range] || t == OneRowRelation || t.isInstanceOf[LocalRelation] => failAnalysis(s"Inserting into an RDD-based table is not allowed.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c5f91c159054..35fd800df4a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -161,7 +161,6 @@ object FunctionRegistry { val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), - expression[CreateArray]("array"), expression[Coalesce]("coalesce"), expression[Explode]("explode"), expression[Greatest]("greatest"), @@ -172,10 +171,6 @@ object FunctionRegistry { expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), - expression[CreateMap]("map"), - expression[MapKeys]("map_keys"), - expression[MapValues]("map_values"), - expression[CreateNamedStruct]("named_struct"), expression[NaNvl]("nanvl"), expression[NullIf]("nullif"), expression[Nvl]("nvl"), @@ -184,7 +179,6 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[Stack]("stack"), - expression[CreateStruct]("struct"), expression[CaseWhen]("when"), // math functions @@ -354,9 +348,15 @@ object FunctionRegistry { expression[TimeWindow]("window"), // collection functions + expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[CreateMap]("map"), + expression[CreateNamedStruct]("named_struct"), + expression[MapKeys]("map_keys"), + expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[CreateStruct]("struct"), // misc functions expression[AssertTrue]("assert_true"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala new file mode 100644 index 000000000000..7323197b10f6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. + */ +object ResolveInlineTables extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case table: UnresolvedInlineTable if table.expressionsResolved => + validateInputDimension(table) + validateInputEvaluable(table) + convert(table) + } + + /** + * Validates the input data dimension: + * 1. All rows have the same cardinality. + * 2. The number of column aliases defined is consistent with the number of columns in data. + * + * This is package visible for unit testing. + */ + private[analysis] def validateInputDimension(table: UnresolvedInlineTable): Unit = { + if (table.rows.nonEmpty) { + val numCols = table.names.size + table.rows.zipWithIndex.foreach { case (row, ri) => + if (row.size != numCols) { + table.failAnalysis(s"expected $numCols columns but found ${row.size} columns in row $ri") + } + } + } + } + + /** + * Validates that all inline table data are valid expressions that can be evaluated + * (in this they must be foldable). + * + * This is package visible for unit testing. + */ + private[analysis] def validateInputEvaluable(table: UnresolvedInlineTable): Unit = { + table.rows.foreach { row => + row.foreach { e => + // Note that nondeterministic expressions are not supported since they are not foldable. + if (!e.resolved || !e.foldable) { + e.failAnalysis(s"cannot evaluate expression ${e.sql} in inline table definition") + } + } + } + } + + /** + * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]] + * into a [[LocalRelation]]. + * + * This function attempts to coerce inputs into consistent types. + * + * This is package visible for unit testing. + */ + private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = { + // For each column, traverse all the values and find a common data type and nullability. + val fields = table.rows.transpose.zip(table.names).map { case (column, name) => + val inputTypes = column.map(_.dataType) + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { + table.failAnalysis(s"incompatible types found in column $name for inline table") + } + StructField(name, tpe, nullable = column.exists(_.nullable)) + } + val attributes = StructType(fields).toAttributes + assert(fields.size == table.names.size) + + val newRows: Seq[InternalRow] = table.rows.map { row => + InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => + val targetType = fields(ci).dataType + try { + if (e.dataType.sameType(targetType)) { + e.eval() + } else { + Cast(e, targetType).eval() + } + } catch { + case NonFatal(ex) => + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + } + }) + } + + LocalRelation(attributes, newRows) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala new file mode 100644 index 000000000000..6b3bb68538dd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} + +/** + * Rule that resolves table-valued function references. + */ +object ResolveTableValuedFunctions extends Rule[LogicalPlan] { + /** + * List of argument names and their types, used to declare a function. + */ + private case class ArgumentList(args: (String, DataType)*) { + /** + * Try to cast the expressions to satisfy the expected types of this argument list. If there + * are any types that cannot be casted, then None is returned. + */ + def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { + if (args.length == values.length) { + val casted = values.zip(args).map { case (value, (_, expectedType)) => + TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) + } + if (casted.forall(_.isDefined)) { + return Some(casted.map(_.get)) + } + } + None + } + + override def toString: String = { + args.map { a => + s"${a._1}: ${a._2.typeName}" + }.mkString(", ") + } + } + + /** + * A TVF maps argument lists to resolver functions that accept those arguments. Using a map + * here allows for function overloading. + */ + private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] + + /** + * TVF builder. + */ + private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) + : (ArgumentList, Seq[Any] => LogicalPlan) = { + (ArgumentList(args: _*), + pf orElse { + case args => + throw new IllegalArgumentException( + "Invalid arguments for resolved function: " + args.mkString(", ")) + }) + } + + /** + * Internal registry of table-valued functions. + */ + private val builtinFunctions: Map[String, TVF] = Map( + "range" -> Map( + /* range(end) */ + tvf("end" -> LongType) { case Seq(end: Long) => + Range(0, end, 1, None) + }, + + /* range(start, end) */ + tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => + Range(start, end, 1, None) + }, + + /* range(start, end, step) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { + case Seq(start: Long, end: Long, step: Long) => + Range(start, end, step, None) + }, + + /* range(start, end, step, numPartitions) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, + "numPartitions" -> IntegerType) { + case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => + Range(start, end, step, Some(numPartitions)) + }) + ) + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => + builtinFunctions.get(u.functionName) match { + case Some(tvf) => + val resolved = tvf.flatMap { case (argList, resolver) => + argList.implicitCast(u.functionArgs) match { + case Some(casted) => + Some(resolver(casted.map(_.eval()))) + case _ => + None + } + } + resolved.headOption.getOrElse { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: (${argTypes})""".stripMargin) + } + case _ => + u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 9a040f8644fb..e773f3dbbcf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -63,7 +63,7 @@ object TypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - private[sql] val numericPrecedence = + val numericPrecedence = IndexedSeq( ByteType, ShortType, @@ -108,18 +108,6 @@ object TypeCoercion { }) } - /** - * Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use - * [[findTightestCommonTypeToString]] to find the TightestCommonType. - */ - private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case None => None - case Some(d) => - findTightestCommonTypeToString(d, c) - }) - } - /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -157,6 +145,28 @@ object TypeCoercion { }) } + /** + * Similar to [[findWiderCommonType]], but can't promote to string. This is also similar to + * [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds + * system limitation, this rule will truncate the decimal type before return it. + */ + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => + Some(DoubleType) + case _ => None + }) + case None => None + }) + } + private def haveSameType(exprs: Seq[Expression]): Boolean = exprs.map(_.dataType).distinct.length == 1 @@ -302,23 +312,22 @@ object TypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisons - // This behaves as a user would expect because timestamp strings sort lexicographically. - // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true + // Parsing of partial dates/timestamps has been added for SPARK-8995 hence + // converting strings to dates/timestamps. case p @ BinaryComparison(left @ StringType(), right @ DateType()) => - p.makeCopy(Array(left, Cast(right, StringType))) + p.makeCopy(Array(Cast(left, DateType), right)) case p @ BinaryComparison(left @ DateType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) + p.makeCopy(Array(left, Cast(right, DateType))) case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(left, Cast(right, StringType))) + p.makeCopy(Array(Cast(left, TimestampType), right)) case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) + p.makeCopy(Array(left, Cast(right, TimestampType))) // Comparisons between dates and timestamps. case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + p.makeCopy(Array(left, Cast(right, TimestampType))) case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) + p.makeCopy(Array(Cast(left, TimestampType), right)) // Checking NullType case p @ BinaryComparison(left @ StringType(), right @ NullType()) => @@ -332,13 +341,13 @@ object TypeCoercion { p.makeCopy(Array(left, Cast(right, DoubleType))) case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) => - i.makeCopy(Array(Cast(a, StringType), b)) + i.makeCopy(Array(a, b.map(Cast(_, DateType)))) case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) => i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) => - i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) + i.makeCopy(Array(Cast(a, TimestampType), b)) case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) => - i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) + i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) @@ -440,7 +449,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -451,7 +460,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -461,7 +470,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -494,16 +503,19 @@ object TypeCoercion { case None => c } + // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if + // we need to truncate, but we should not promote one side to string if the other side is + // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonType(types) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonType(types) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -530,11 +542,14 @@ object TypeCoercion { // Decimal and Double remain the same case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d - case Divide(left, right) if isNumeric(left) && isNumeric(right) => + case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) } - private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType] + private def isNumericOrNull(ex: Expression): Boolean = { + // We need to handle null types in case a query contains null literals. + ex.dataType.isInstanceOf[NumericType] || ex.dataType == NullType + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 609089a302c8..15239b99c946 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,6 +49,37 @@ case class UnresolvedRelation( override lazy val resolved = false } +/** + * An inline table that has not been resolved yet. Once resolved, it is turned by the analyzer into + * a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]]. + * + * @param names list of column names + * @param rows expressions for the data + */ +case class UnresolvedInlineTable( + names: Seq[String], + rows: Seq[Seq[Expression]]) + extends LeafNode { + + lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved)) + override lazy val resolved = false + override def output: Seq[Attribute] = Nil +} + +/** + * A table-valued function, e.g. + * {{{ + * select * from range(10); + * }}} + */ +case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) + extends LeafNode { + + override def output: Seq[Attribute] = Nil + + override lazy val resolved = false +} + /** * Holds the name of an attribute that has yet to be resolved. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 6714846e8cbd..4371ff3d58a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} /** @@ -38,6 +38,18 @@ abstract class ExternalCatalog { } } + protected def requireFunctionExists(db: String, funcName: String): Unit = { + if (!functionExists(db, funcName)) { + throw new NoSuchFunctionException(db = db, func = funcName) + } + } + + protected def requireFunctionNotExists(db: String, funcName: String): Unit = { + if (functionExists(db, funcName)) { + throw new FunctionAlreadyExistsException(db = db, func = funcName) + } + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index fb3e1b3637f2..ef5a19687ce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -59,18 +59,6 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E catalog(db).tables(table).partitions.contains(spec) } - private def requireFunctionExists(db: String, funcName: String): Unit = { - if (!functionExists(db, funcName)) { - throw new NoSuchFunctionException(db = db, func = funcName) - } - } - - private def requireFunctionNotExists(db: String, funcName: String): Unit = { - if (functionExists(db, funcName)) { - throw new FunctionAlreadyExistsException(db = db, func = funcName) - } - } - private def requireTableExists(db: String, table: String): Unit = { if (!tableExists(db, table)) { throw new NoSuchTableException(db = db, table = table) @@ -465,11 +453,8 @@ class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends E override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) - if (functionExists(db, func.identifier.funcName)) { - throw new FunctionAlreadyExistsException(db = db, func = func.identifier.funcName) - } else { - catalog(db).functions.put(func.identifier.funcName, func) - } + requireFunctionNotExists(db, func.identifier.funcName) + catalog(db).functions.put(func.identifier.funcName, func) } override def dropFunction(db: String, funcName: String): Unit = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3a2e574a1d9a..f455cc90963f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -246,34 +246,26 @@ class SessionCatalog( } /** - * Retrieve the metadata of an existing metastore table. - * If no database is specified, assume the table is in the current database. - * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. + * Return whether a table/view with the specified name exists. If no database is specified, check + * with current database. + */ + def tableExists(name: TableIdentifier): Boolean = synchronized { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + val table = formatTableName(name.table) + externalCatalog.tableExists(db, table) + } + + /** + * Retrieve the metadata of an existing permanent table/view. If no database is specified, + * assume the table/view is in the current database. If the specified table/view is not found + * in the database then a [[NoSuchTableException]] is thrown. */ def getTableMetadata(name: TableIdentifier): CatalogTable = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) - val tid = TableIdentifier(table) - if (isTemporaryTable(name)) { - CatalogTable( - identifier = tid, - tableType = CatalogTableType.VIEW, - storage = CatalogStorageFormat.empty, - schema = tempTables(table).output.map { c => - CatalogColumn( - name = c.name, - dataType = c.dataType.catalogString, - nullable = c.nullable, - comment = Option(c.name) - ) - }, - properties = Map(), - viewText = None) - } else { - requireDbExists(db) - requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.getTable(db, table) - } + requireDbExists(db) + requireTableExists(TableIdentifier(table, Some(db))) + externalCatalog.getTable(db, table) } /** @@ -333,9 +325,9 @@ class SessionCatalog( new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString } - // ------------------------------------------------------------- - // | Methods that interact with temporary and metastore tables | - // ------------------------------------------------------------- + // ---------------------------------------------- + // | Methods that interact with temp views only | + // ---------------------------------------------- /** * Create a temporary table. @@ -351,6 +343,56 @@ class SessionCatalog( tempTables.put(table, tableDefinition) } + /** + * Return a temporary view exactly as it was stored. + */ + def getTempView(name: String): Option[LogicalPlan] = synchronized { + tempTables.get(formatTableName(name)) + } + + /** + * Drop a temporary view. + */ + def dropTempView(name: String): Unit = synchronized { + tempTables.remove(formatTableName(name)) + } + + // ------------------------------------------------------------- + // | Methods that interact with temporary and metastore tables | + // ------------------------------------------------------------- + + /** + * Retrieve the metadata of an existing temporary view or permanent table/view. + * + * If a database is specified in `name`, this will return the metadata of table/view in that + * database. + * If no database is specified, this will first attempt to get the metadata of a temporary view + * with the same name, then, if that does not exist, return the metadata of table/view in the + * current database. + */ + def getTempViewOrPermanentTableMetadata(name: TableIdentifier): CatalogTable = synchronized { + val table = formatTableName(name.table) + if (name.database.isDefined) { + getTableMetadata(name) + } else { + getTempView(table).map { plan => + CatalogTable( + identifier = TableIdentifier(table), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = plan.output.map { c => + CatalogColumn( + name = c.name, + dataType = c.dataType.catalogString, + nullable = c.nullable + ) + }, + properties = Map(), + viewText = None) + }.getOrElse(getTableMetadata(name)) + } + } + /** * Rename a table. * @@ -439,24 +481,6 @@ class SessionCatalog( } } - /** - * Return whether a table with the specified name exists. - * - * Note: If a database is explicitly specified, then this will return whether the table - * exists in that particular database instead. In that case, even if there is a temporary - * table with the same name, we will return false if the specified database does not - * contain the table. - */ - def tableExists(name: TableIdentifier): Boolean = synchronized { - val db = formatDatabaseName(name.database.getOrElse(currentDb)) - val table = formatTableName(name.table) - if (isTemporaryTable(name)) { - true - } else { - externalCatalog.tableExists(db, table) - } - } - /** * Return whether a table with the specified name is a temporary table. * @@ -495,7 +519,7 @@ class SessionCatalog( // If the database is defined, this is definitely not a temp table. // If the database is not defined, there is a good chance this is a temp table. if (name.database.isEmpty) { - tempTables.get(name.table).foreach(_.refresh()) + tempTables.get(formatTableName(name.table)).foreach(_.refresh()) } } @@ -507,14 +531,6 @@ class SessionCatalog( tempTables.clear() } - /** - * Return a temporary table exactly as it was stored. - * For testing only. - */ - private[catalog] def getTempTable(name: String): Option[LogicalPlan] = synchronized { - tempTables.get(name) - } - // ---------------------------------------------------------------------------- // Partitions // ---------------------------------------------------------------------------- @@ -535,11 +551,11 @@ class SessionCatalog( tableName: TableIdentifier, parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = { - requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) externalCatalog.createPartitions(db, table, parts, ignoreIfExists) } @@ -551,11 +567,11 @@ class SessionCatalog( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit = { - requirePartialMatchedPartitionSpec(specs, getTableMetadata(tableName)) val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) + requirePartialMatchedPartitionSpec(specs, getTableMetadata(tableName)) externalCatalog.dropPartitions(db, table, specs, ignoreIfNotExists) } @@ -570,12 +586,12 @@ class SessionCatalog( specs: Seq[TablePartitionSpec], newSpecs: Seq[TablePartitionSpec]): Unit = { val tableMetadata = getTableMetadata(tableName) - requireExactMatchedPartitionSpec(specs, tableMetadata) - requireExactMatchedPartitionSpec(newSpecs, tableMetadata) val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(specs, tableMetadata) + requireExactMatchedPartitionSpec(newSpecs, tableMetadata) externalCatalog.renamePartitions(db, table, specs, newSpecs) } @@ -589,11 +605,11 @@ class SessionCatalog( * this becomes a no-op. */ def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = { - requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) externalCatalog.alterPartitions(db, table, parts) } @@ -602,11 +618,11 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. */ def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = { - requireExactMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) externalCatalog.getPartition(db, table, spec) } @@ -746,7 +762,7 @@ class SessionCatalog( * * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - private[sql] def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { // TODO: at least support UDAFs here throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } @@ -790,7 +806,7 @@ class SessionCatalog( /** * Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists. */ - private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized { + def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized { // TODO: just make function registry take in FunctionIdentifier instead of duplicating this val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) val qualifiedName = name.copy(database = database) @@ -902,7 +918,7 @@ class SessionCatalog( * * This is mainly used for tests. */ - private[sql] def reset(): Unit = synchronized { + def reset(): Unit = synchronized { setCurrentDatabase(DEFAULT_DATABASE) listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => dropDatabase(db, ignoreIfNotExists = false, cascade = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 6197acab3378..e7430b030901 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -103,10 +103,12 @@ case class CatalogColumn( * * @param spec partition spec values indexed by column name * @param storage storage format of the partition + * @param parameters some parameters for the partition, for example, stats. */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, - storage: CatalogStorageFormat) + storage: CatalogStorageFormat, + parameters: Map[String, String] = Map.empty) /** @@ -203,7 +205,6 @@ case class CatalogTableType private(name: String) object CatalogTableType { val EXTERNAL = new CatalogTableType("EXTERNAL") val MANAGED = new CatalogTableType("MANAGED") - val INDEX = new CatalogTableType("INDEX") val VIEW = new CatalogTableType("VIEW") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 67fca153b551..2a6fcd03a26b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -206,6 +206,7 @@ object RowEncoder { case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) + case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) case udt: UserDefinedType[_] => ObjectType(udt.userClass) } @@ -220,9 +221,15 @@ object RowEncoder { CreateExternalRow(fields, schema) } - private def deserializerFor(input: Expression): Expression = input.dataType match { + private def deserializerFor(input: Expression): Expression = { + deserializerFor(input, input.dataType) + } + + private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { case dt if ScalaReflection.isNativeType(dt) => input + case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) + case udt: UserDefinedType[_] => val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) val udtClass: Class[_] = if (annotation != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 03708fb7afd4..59f7969e5614 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -26,7 +26,7 @@ package object encoders { * references from a specific schema.) This requirement allows us to preserve whether a given * object type is being bound by name or by ordinal when doing resolution. */ - private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { case e: ExpressionEncoder[A] => e.assertUnresolved() e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0420b4b5387c..0d45f371fa0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst +import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.SparkException /** * Functions for attaching and retrieving trees that are associated with errors. @@ -47,7 +50,10 @@ package object errors { */ def attachTree[TreeType <: TreeNode[_], A](tree: TreeType, msg: String = "")(f: => A): A = { try f catch { - case e: Exception => throw new TreeNodeException(tree, msg, e) + // SPARK-16748: We do not want SparkExceptions from job failures in the planning phase + // to create TreeNodeException. Hence, wrap exception only if it is not SparkException. + case NonFatal(e) if !e.isInstanceOf[SparkException] => + throw new TreeNodeException(tree, msg, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c452765af2dd..70fff5195625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -416,7 +416,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } private[this] def cast(from: DataType, to: DataType): Any => Any = to match { - case dt if dt == child.dataType => identity[Any] + case dt if dt == from => identity[Any] case StringType => castToString(from) case BinaryType => castToBinary(from) case DateType => castToDate(from) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1f37b68846ae..7abbbe257d83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -526,7 +526,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { } -private[sql] object BinaryOperator { +object BinaryOperator { def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 75c6bb2d84df..5b4922e0cf2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.{DataType, LongType} represent the record number within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records.""", extended = "> SELECT _FUNC_();\n 0") -private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { +case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { /** * Record ID within each partition. By being transient, count's value is reset to 0 every time diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index c8d18667f7c4..10ae4021d47a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -14,12 +14,30 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types._ /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -67,13 +85,45 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu case n: Nondeterministic => n.setInitialValues() case _ => }) - + private var targetUnsafe = false + type UnsafeSetter = (UnsafeRow, Any) => Unit + private var setters: Array[UnsafeSetter] = _ private[this] val exprArray = expressions.toArray private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) def currentValue: InternalRow = mutableRow + override def target(row: MutableRow): MutableProjection = { mutableRow = row + targetUnsafe = row match { + case _: UnsafeRow => + if (setters == null) { + setters = Array.ofDim[UnsafeSetter](exprArray.length) + for (i <- exprArray.indices) { + setters(i) = exprArray(i).dataType match { + case IntegerType => (target: UnsafeRow, value: Any) => + target.setInt(i, value.asInstanceOf[Int]) + case LongType => (target: UnsafeRow, value: Any) => + target.setLong(i, value.asInstanceOf[Long]) + case DoubleType => (target: UnsafeRow, value: Any) => + target.setDouble(i, value.asInstanceOf[Double]) + case FloatType => (target: UnsafeRow, value: Any) => + target.setFloat(i, value.asInstanceOf[Float]) + case NullType => (target: UnsafeRow, value: Any) => + target.setNullAt(i) + case BooleanType => (target: UnsafeRow, value: Any) => + target.setBoolean(i, value.asInstanceOf[Boolean]) + case ByteType => (target: UnsafeRow, value: Any) => + target.setByte(i, value.asInstanceOf[Byte]) + case ShortType => (target: UnsafeRow, value: Any) => + target.setShort(i, value.asInstanceOf[Short]) + } + } + } + true + case _ => false + } + this } @@ -86,7 +136,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu } i = 0 while (i < exprArray.length) { - mutableRow(i) = buffer(i) + if (targetUnsafe) { + setters(i)(mutableRow.asInstanceOf[UnsafeRow], buffer(i)) + } else { + mutableRow(i) = buffer(i) + } i += 1 } mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 21390644bc0b..6cfdea9fdf9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType @@ -994,20 +995,15 @@ case class ScalaUDF( ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.references += this - - val scalaUDFClassName = classOf[ScalaUDF].getName + val scalaUDF = ctx.addReferenceObj("scalaUDF", this) val converterClassName = classOf[Any => Any].getName val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - val expressionClassName = classOf[Expression].getName // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - val catalystConverterTermIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, catalystConverterTerm, s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter((($scalaUDFClassName)references" + - s"[$catalystConverterTermIdx]).dataType());") + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1019,10 +1015,8 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - val funcExpressionIdx = ctx.references.size - 1 ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" + - s"[$funcExpressionIdx]).userDefinedFunc());") + s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1039,9 +1033,16 @@ case class ScalaUDF( (convert, argTerm) }.unzip - val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + - s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + - s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" + val callFunc = + s""" + ${ctx.boxedType(dataType)} $resultTerm = null; + try { + $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + } catch (Exception e) { + throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + } + """ ev.copy(code = s""" $evalCode @@ -1057,5 +1058,20 @@ case class ScalaUDF( private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) - override def eval(input: InternalRow): Any = converter(f(input)) + lazy val udfErrorMessage = { + val funcCls = function.getClass.getSimpleName + val inputTypes = children.map(_.dataType.simpleString).mkString(", ") + s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})" + } + + override def eval(input: InternalRow): Any = { + val result = try { + f(input) + } catch { + case e: Exception => + throw new SparkException(udfErrorMessage, e) + } + + converter(result) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 71af59a7a852..1f675d5b0727 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType} @ExpressionDescription( usage = "_FUNC_() - Returns the current partition id of the Spark task", extended = "> SELECT _FUNC_();\n 0") -private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { +case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 66c4bf29ea4b..7ff61ee47945 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -45,12 +45,12 @@ case class TimeWindow( slideDuration: Expression, startTime: Expression) = { this(timeColumn, TimeWindow.parseExpression(windowDuration), - TimeWindow.parseExpression(windowDuration), TimeWindow.parseExpression(startTime)) + TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime)) } def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = { this(timeColumn, TimeWindow.parseExpression(windowDuration), - TimeWindow.parseExpression(windowDuration), 0) + TimeWindow.parseExpression(slideDuration), 0) } def this(timeColumn: Expression, windowDuration: Expression) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index ff7077484783..70d59f9dc7d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -53,10 +53,16 @@ case class Average(child: Expression) extends DeclarativeAggregate { } private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = sum :: count :: Nil + override lazy val aggBufferAttributesForGroup: Seq[AttributeReference] = { + if (child.nullable) aggBufferAttributes + else sum.copy(nullable = false)(sum.exprId, sum.qualifier, + sum.isGenerated) :: count :: Nil + } + override lazy val initialValues = Seq( /* sum = */ Cast(Literal(0), sumDataType), /* count = */ Literal(0L) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index ad217f25b5a2..aecd58c1fb54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -53,10 +53,22 @@ case class Sum(child: Expression) extends DeclarativeAggregate { override lazy val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributesForGroup: Seq[AttributeReference] = { + if (child.nullable) aggBufferAttributes + else sum.copy(nullable = false)(sum.exprId, sum.qualifier, + sum.isGenerated) :: Nil + } + override lazy val initialValues: Seq[Expression] = Seq( /* sum = */ Literal.create(null, sumDataType) ) + override lazy val initialValuesForGroup: Seq[Expression] = Seq( + /* sum = */ + if (child.nullable) Literal.create(null, sumDataType) + else Cast(Literal(0), sumDataType) + ) + override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index ac2cefaddcf5..78a388d20630 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -54,6 +54,10 @@ abstract class Collect extends ImperativeAggregate { override def inputAggBufferAttributes: Seq[AttributeReference] = Nil + // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the + // actual order of input rows. + override def deterministic: Boolean = false + protected[this] val buffer: Growable[Any] with Iterable[Any] override def initialize(b: MutableRow): Unit = { @@ -61,7 +65,12 @@ abstract class Collect extends ImperativeAggregate { } override def update(b: MutableRow, input: InternalRow): Unit = { - buffer += child.eval(input) + // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. + // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator + val value = child.eval(input) + if (value != null) { + buffer += value + } } override def merge(buffer: MutableRow, input: InternalRow): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 504cea52797d..5fd17465b803 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.catalyst.expressions.aggregate @@ -24,14 +42,14 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ -private[sql] sealed trait AggregateMode +sealed trait AggregateMode /** * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ -private[sql] case object Partial extends AggregateMode +case object Partial extends AggregateMode /** * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers @@ -39,7 +57,7 @@ private[sql] case object Partial extends AggregateMode * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. */ -private[sql] case object PartialMerge extends AggregateMode +case object PartialMerge extends AggregateMode /** * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers @@ -47,7 +65,7 @@ private[sql] case object PartialMerge extends AggregateMode * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Final extends AggregateMode +case object Final extends AggregateMode /** * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly @@ -55,13 +73,13 @@ private[sql] case object Final extends AggregateMode * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Complete extends AggregateMode +case object Complete extends AggregateMode /** * A place holder expressions used in code-gen, it does not change the corresponding value * in the row. */ -private[sql] case object NoOp extends Expression with Unevaluable { +case object NoOp extends Expression with Unevaluable { override def nullable: Boolean = true override def dataType: DataType = NullType override def children: Seq[Expression] = Nil @@ -84,7 +102,7 @@ object AggregateExpression { * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. */ -private[sql] case class AggregateExpression( +case class AggregateExpression( aggregateFunction: AggregateFunction, mode: AggregateMode, isDistinct: Boolean, @@ -166,6 +184,9 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu /** Attributes of fields in aggBufferSchema. */ def aggBufferAttributes: Seq[AttributeReference] + /** Attributes of fields in aggBufferSchema used for group by. */ + def aggBufferAttributesForGroup: Seq[AttributeReference] = aggBufferAttributes + /** * Attributes of fields in input aggregation buffers (immutable aggregation buffers that are * merged with mutable aggregation buffers in the merge() function or merge expressions). @@ -349,6 +370,11 @@ abstract class DeclarativeAggregate */ val initialValues: Seq[Expression] + /** + * Expressions for initializing empty aggregation buffers for group by. + */ + def initialValuesForGroup: Seq[Expression] = initialValues + /** * Expressions for updating the mutable aggregation buffer based on an input row. */ @@ -371,8 +397,16 @@ abstract class DeclarativeAggregate /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */ final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - final lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) + lazy val inputAggBufferbaseExprID = NamedExpression.allocateExprID(aggBufferAttributes.length) + + /* final lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) */ + + @transient final lazy val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.zipWithIndex.map { + case ( attr, i) => attr.withExprId( ExprId( inputAggBufferbaseExprID.id + i, + inputAggBufferbaseExprID.jvmId)) + } /** * A helper class for representing an attribute used in merging two diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 91ffac0ba2a6..01c5d8217074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression } } - override def sql: String = s"(-${child.sql})" + override def sql: String = s"(- ${child.sql})" } @ExpressionDescription( @@ -75,7 +75,7 @@ case class UnaryPositive(child: Expression) protected override def nullSafeEval(input: Any): Any = input - override def sql: String = s"(+${child.sql})" + override def sql: String = s"(+ ${child.sql})" } /** @@ -125,7 +125,7 @@ abstract class BinaryArithmetic extends BinaryOperator { } } -private[sql] object BinaryArithmetic { +object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } @@ -309,7 +309,11 @@ case class Remainder(left: Expression, right: Expression) if (input1 == null) { null } else { - integral.rem(input1, input2) + input1 match { + case d: Double => d % input2.asInstanceOf[java.lang.Double] + case f: Float => f % input2.asInstanceOf[java.lang.Float] + case _ => integral.rem(input1, input2) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 16fb1f683710..929f2da07531 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -23,6 +23,7 @@ import java.util.{Map => JavaMap} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler} @@ -584,15 +585,18 @@ class CodegenContext { * @param expressions the codes to evaluate expressions. */ def splitExpressions(row: String, expressions: Seq[String]): String = { - if (row == null) { + if (row == null || currentVars != null) { // Cannot split these expressions because they are not created from a row object. return expressions.mkString("\n") } val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() for (code <- expressions) { - // We can't know how many byte code will be generated, so use the number of bytes as limit - if (blockBuilder.length > 64 * 1000) { + // We can't know how many bytecode will be generated, so use the length of source code + // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should + // also not be too small, or it will have many function calls (for wide table), see the + // results in BenchmarkWideTable. + if (blockBuilder.length > 1024) { blocks.append(blockBuilder.toString()) blockBuilder.clear() } @@ -911,14 +915,19 @@ object CodeGenerator extends Logging { codeAttrField.setAccessible(true) classes.foreach { case (_, classBytes) => CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) - val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - cf.methodInfos.asScala.foreach { method => - method.getAttributes().foreach { a => - if (a.getClass.getName == codeAttr.getName) { - CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( - codeAttrField.get(a).asInstanceOf[Array[Byte]].length) + try { + val cf = new ClassFile(new ByteArrayInputStream(classBytes)) + cf.methodInfos.asScala.foreach { method => + method.getAttributes().foreach { a => + if (a.getClass.getName == codeAttr.getName) { + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( + codeAttrField.get(a).asInstanceOf[Array[Byte]].length) + } } } + } catch { + case NonFatal(e) => + logWarning("Error calculating stats of compiled class.", e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0ca715f42472..09e22aaf3e3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -84,8 +84,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { @ExpressionDescription( usage = "_FUNC_(key0, value0, key1, value1...) - Creates a map with the given key/value pairs.") case class CreateMap(children: Seq[Expression]) extends Expression { - private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children) - private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children) + lazy val keys = children.indices.filter(_ % 2 == 0).map(children) + lazy val values = children.indices.filter(_ % 2 != 0).map(children) override def foldable: Boolean = children.forall(_.foldable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3b4468f55ca7..abb5594bfa7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -106,7 +106,7 @@ trait ExtractValue extends Expression case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression with ExtractValue { - private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] + lazy val childSchema = child.dataType.asInstanceOf[StructType] override def dataType: DataType = childSchema(ordinal).dataType override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index e97e08947a50..f9499cf78569 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -126,7 +126,8 @@ abstract class CaseWhenBase( override def eval(input: InternalRow): Any = { var i = 0 - while (i < branches.size) { + val size = branches.size + while (i < size) { if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) { return branches(i)._2.eval(input) } @@ -299,7 +300,7 @@ case class Least(children: Seq[Expression]) extends Expression { } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got LEAST (${children.map(_.dataType)}).") + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) } @@ -359,7 +360,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got GREATEST (${children.map(_.dataType)}).") + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 95ed68fbb052..41e3952f0e25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -163,8 +163,7 @@ object DecimalLiteral { /** * In order to do type checking, use Literal.create() instead of constructor */ -case class Literal protected (value: Any, dataType: DataType) - extends LeafExpression with CodegenFallback { +case class Literal (value: Any, dataType: DataType) extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -246,15 +245,28 @@ case class Literal protected (value: Any, dataType: DataType) case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" case _ if value == null => s"CAST(NULL AS ${dataType.sql})" case (v: UTF8String, StringType) => - // Escapes all backslashes and double quotes. - "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" + // Escapes all backslashes and single quotes. + "'" + v.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" case (v: Byte, ByteType) => v + "Y" case (v: Short, ShortType) => v + "S" case (v: Long, LongType) => v + "L" // Float type doesn't have a suffix - case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})" - case (v: Double, DoubleType) => v + "D" - case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})" + case (v: Float, FloatType) => + val castedValue = v match { + case _ if v.isNaN => "'NaN'" + case Float.PositiveInfinity => "'Infinity'" + case Float.NegativeInfinity => "'-Infinity'" + case _ => v + } + s"CAST($castedValue AS ${FloatType.sql})" + case (v: Double, DoubleType) => + v match { + case _ if v.isNaN => s"CAST('NaN' AS ${DoubleType.sql})" + case Double.PositiveInfinity => s"CAST('Infinity' AS ${DoubleType.sql})" + case Double.NegativeInfinity => s"CAST('-Infinity' AS ${DoubleType.sql})" + case _ => v + "D" + } + case (v: Decimal, t: DecimalType) => v + "BD" case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'" case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" case _ => value.toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index d2c94ec1df4d..92f8fb85fc0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -477,10 +477,13 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input + private val outputPrefix = s"Result of ${child.simpleString} is " + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val outputPrefixField = ctx.addReferenceObj("outputPrefix", outputPrefix) nullSafeCodeGen(ctx, ev, c => s""" - | System.err.println("Result of ${child.simpleString} is " + $c); + | System.err.println($outputPrefixField + $c); | ${ev.value} = $c; """.stripMargin) } @@ -501,10 +504,12 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def prettyName: String = "assert_true" + private val errMsg = s"'${child.simpleString}' is not true!" + override def eval(input: InternalRow) : Any = { val v = child.eval(input) if (v == null || java.lang.Boolean.FALSE.equals(v)) { - throw new RuntimeException(s"'${child.simpleString}' is not true!") + throw new RuntimeException(errMsg) } else { null } @@ -512,9 +517,10 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { - | throw new RuntimeException("'${child.simpleString}' is not true."); + | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = "true", value = "null") } @@ -554,7 +560,7 @@ object XxHash64Function extends InterpretedHashFunction { @ExpressionDescription( usage = "_FUNC_() - Returns the current database.", extended = "> SELECT _FUNC_()") -private[sql] case class CurrentDatabase() extends LeafExpression with Unevaluable { +case class CurrentDatabase() extends LeafExpression with Unevaluable { override def dataType: DataType = StringType override def foldable: Boolean = true override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 306a99d5a37b..aeb7a2b0a9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.catalyst.expressions @@ -30,6 +48,8 @@ object NamedExpression { private[expressions] val jvmId = UUID.randomUUID() def newExprId: ExprId = ExprId(curId.getAndIncrement(), jvmId) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) + def allocateExprID(quota: Int): ExprId = ExprId(curId.getAndAdd(quota), jvmId) + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ea4dee174e74..691edd5c2e7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -232,27 +232,47 @@ case class NewInstance( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") + val argIsNulls = ctx.freshName("argIsNulls") + ctx.addMutableState("boolean[]", argIsNulls, + s"$argIsNulls = new boolean[${arguments.size}];") + val argValues = arguments.zipWithIndex.map { case (e, i) => + val argValue = ctx.freshName("argValue") + ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") + argValue + } + + val argCodes = arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + expr.code + s""" + $argIsNulls[$i] = ${expr.isNull}; + ${argValues(i)} = ${expr.value}; + """ + } + val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) var isNull = ev.isNull val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};" + s""" + boolean $isNull = false; + for (int idx = 0; idx < ${arguments.length}; idx++) { + if ($argIsNulls[idx]) { $isNull = true; break; } + } + """ } else { isNull = "false" "" } val constructorCall = outer.map { gen => - s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})""" }.getOrElse { - s"new $className($argString)" + s"new $className(${argValues.mkString(", ")})" } val code = s""" - ${argGen.map(_.code).mkString("\n")} + $argCode ${outer.map(_.code).getOrElse("")} $setIsNull final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; @@ -346,6 +366,13 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext object MapObjects { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * Construct an instance of MapObjects case class. + * + * @param function The function applied on the collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param elementType The data type of elements in the collection. + */ def apply( function: Expression => Expression, inputData: Expression, @@ -433,8 +460,14 @@ case class MapObjects private( case _ => "" } + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + val inputDataType = inputData.dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => inputData.dataType + } - val (getLength, getLoopVar) = inputData.dataType match { + val (getLength, getLoopVar) = inputDataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" case ObjectType(cls) if cls.isArray => @@ -448,7 +481,17 @@ case class MapObjects private( s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" } - val loopNullCheck = inputData.dataType match { + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" + val genFunctionValue = lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + + val loopNullCheck = inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => @@ -475,7 +518,7 @@ case class MapObjects private( if (${genFunction.isNull}) { $convertedArray[$loopIndex] = null; } else { - $convertedArray[$loopIndex] = ${genFunction.value}; + $convertedArray[$loopIndex] = $genFunctionValue; } $loopIndex += 1; @@ -720,7 +763,10 @@ case class GetExternalRowField( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) val code = s""" ${row.code} @@ -730,8 +776,7 @@ case class GetExternalRowField( } if (${row.value}.isNullAt($index)) { - throw new RuntimeException("The ${index}th field '$fieldName' of input row " + - "cannot be null."); + throw new RuntimeException($errMsgField); } final Object ${ev.value} = ${row.value}.get($index); @@ -756,7 +801,10 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val input = child.genCode(ctx) val obj = input.value @@ -777,8 +825,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) if ($typeCheck) { ${ev.value} = (${ctx.boxedType(dataType)}) $obj; } else { - throw new RuntimeException($obj.getClass().getName() + " is not a valid " + - "external type for schema of ${expected.simpleString}"); + throw new RuntimeException($obj.getClass().getName() + $errMsgField); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 6112259fed61..9a892905f518 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -31,7 +31,8 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 - while (i < ordering.size) { + val size = ordering.size + while (i < size) { val order = ordering(i) val left = order.child.eval(a) val right = order.child.eval(b) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a3b098afe572..100087ed5891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -393,13 +393,13 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } -private[sql] object BinaryComparison { +object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } /** An extractor that matches both standard 3VL equality and null-safe equality. */ -private[sql] object Equality { +object Equality { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { case EqualTo(l, r) => Some((l, r)) case EqualNullSafe(l, r) => Some((l, r)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index be82b3b8f45f..d25da3fd587b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -329,7 +329,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val m = pattern.matcher(s.toString) if (m.find) { val mr: MatchResult = m.toMatchResult - UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + val group = mr.group(r.asInstanceOf[Int]) + if (group == null) { // Pattern matched, but not optional group + UTF8String.EMPTY_UTF8 + } else { + UTF8String.fromString(group) + } } else { UTF8String.EMPTY_UTF8 } @@ -367,7 +372,11 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ${termPattern}.matcher($subject.toString()); if (${matcher}.find()) { java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); - ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + if (${matchResult}.group($idx) == null) { + ${ev.value} = UTF8String.EMPTY_UTF8; + } else { + ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + } $setEvNotNull } else { ${ev.value} = UTF8String.EMPTY_UTF8; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index e036982e70f9..73dceb35ac50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -218,7 +218,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { +class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f24f8b78d476..d824c2e26d71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -75,7 +75,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Operator Optimizations", fixedPoint, // Operator push down PushThroughSetOperations, - PushProjectThroughSample, ReorderJoin, EliminateOuterJoin, PushPredicateThroughJoin, @@ -146,17 +145,6 @@ class SimpleTestOptimizer extends Optimizer( new SimpleCatalystConf(caseSensitiveAnalysis = true)), new SimpleCatalystConf(caseSensitiveAnalysis = true)) -/** - * Pushes projects down beneath Sample to enable column pruning with sampling. - */ -object PushProjectThroughSample extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down projection into sample - case Project(projectList, Sample(lb, up, replace, seed, child)) => - Sample(lb, up, replace, seed, Project(projectList, child))() - } -} - /** * Removes the Project only conducting Alias of its child node. * It is created mainly for removing extra Project added in EliminateSerialization rule, @@ -198,25 +186,6 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] { } } -/** - * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) - * representation of data item. For example back to back map operations. - */ -object EliminateSerialization extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case d @ DeserializeToObject(_, _, s: SerializeFromObject) - if d.outputObjectType == s.inputObjectType => - // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. - // We will remove it later in RemoveAliasOnlyProject rule. - val objAttr = - Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId) - Project(objAttr :: Nil, s.child) - case a @ AppendColumns(_, _, _, s: SerializeFromObject) - if a.deserializer.dataType == s.inputObjectType => - AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) - } -} - /** * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. */ @@ -594,6 +563,8 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { + case e @ WindowExpression(Cast(Literal(0L, _), _), _) => + Cast(Literal(0L), e.dataType) case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) @@ -695,6 +666,19 @@ object FoldablePropagation extends Rule[LogicalPlan] { case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) => stop = true j + + // These 3 operators take attributes as constructor parameters, and these attributes + // can't be replaced by alias. + case m: MapGroups => + stop = true + m + case f: FlatMapGroupsInR => + stop = true + f + case c: CoGroup => + stop = true + c + case p: LogicalPlan if !stop => p.transformExpressions { case a: AttributeReference if foldableMap.contains(a) => foldableMap(a) @@ -1120,17 +1104,28 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - // two filters should be combine together by other rules - case filter @ Filter(_, _: Filter) => filter - // should not push predicates through sample, or will generate different results. - case filter @ Filter(_, _: Sample) => filter - - case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => + case filter @ Filter(condition, u: UnaryNode) + if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) } } + private def canPushThrough(p: UnaryNode): Boolean = p match { + // Note that some operators (e.g. project, aggregate, union) are being handled separately + // (earlier in this rule). + case _: AppendColumns => true + case _: BroadcastHint => true + case _: Distinct => true + case _: Generate => true + case _: Pivot => true + case _: RedistributeData => true + case _: Repartition => true + case _: ScriptTransformation => true + case _: Sort => true + case _ => false + } + private def pushDownPredicate( filter: Filter, grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { @@ -1154,118 +1149,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } -/** - * Reorder the joins and push all the conditions into join, so that the bottom ones have at least - * one condition. - * - * The order of joins will not be changed if all of them already have at least one condition. - */ -object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { - - /** - * Join a list of plans together and push down the conditions into them. - * - * The joined plan are picked from left to right, prefer those has at least one join condition. - * - * @param input a list of LogicalPlans to join. - * @param conditions a list of condition for join. - */ - @tailrec - def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { - assert(input.size >= 2) - if (input.size == 2) { - val (joinConditions, others) = conditions.partition( - e => !SubqueryExpression.hasCorrelatedSubquery(e)) - val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And)) - if (others.nonEmpty) { - Filter(others.reduceLeft(And), join) - } else { - join - } - } else { - val left :: rest = input.toList - // find out the first join that have at least one join condition - val conditionalJoin = rest.find { plan => - val refs = left.outputSet ++ plan.outputSet - conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) - .exists(_.references.subsetOf(refs)) - } - // pick the next one if no condition left - val right = conditionalJoin.getOrElse(rest.head) - - val joinedRefs = left.outputSet ++ right.outputSet - val (joinConditions, others) = conditions.partition( - e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) - val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) - - // should not have reference to same logical plan - createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) - } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ ExtractFiltersAndInnerJoins(input, conditions) - if input.size > 2 && conditions.nonEmpty => - createOrderedJoin(input, conditions) - } -} - -/** - * Elimination of outer joins, if the predicates can restrict the result sets so that - * all null-supplying rows are eliminated - * - * - full outer -> inner if both sides have such predicates - * - left outer -> inner if the right side has such predicates - * - right outer -> inner if the left side has such predicates - * - full outer -> left outer if only the left side has such predicates - * - full outer -> right outer if only the right side has such predicates - * - * This rule should be executed before pushing down the Filter - */ -object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { - - /** - * Returns whether the expression returns null or false when all inputs are nulls. - */ - private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false - val attributes = e.references.toSeq - val emptyRow = new GenericInternalRow(attributes.length) - val v = BindReferences.bindReference(e, attributes).eval(emptyRow) - v == null || v == false - } - - private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val splitConjunctiveConditions: Seq[Expression] = splitConjunctivePredicates(filter.condition) - val leftConditions = splitConjunctiveConditions - .filter(_.references.subsetOf(join.left.outputSet)) - val rightConditions = splitConjunctiveConditions - .filter(_.references.subsetOf(join.right.outputSet)) - - val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) || - filter.constraints.filter(_.isInstanceOf[IsNotNull]) - .exists(expr => join.left.outputSet.intersect(expr.references).nonEmpty) - val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) || - filter.constraints.filter(_.isInstanceOf[IsNotNull]) - .exists(expr => join.right.outputSet.intersect(expr.references).nonEmpty) - - join.joinType match { - case RightOuter if leftHasNonNullPredicate => Inner - case LeftOuter if rightHasNonNullPredicate => Inner - case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner - case FullOuter if leftHasNonNullPredicate => LeftOuter - case FullOuter if rightHasNonNullPredicate => RightOuter - case o => o - } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => - val newJoinType = buildNewJoinType(f, j) - if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) - } -} - /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other @@ -1278,18 +1161,25 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { */ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** - * Splits join condition expressions into three categories based on the attributes required - * to evaluate them. + * Splits join condition expressions or filter predicates (on a given join's output) into three + * categories based on the attributes required to evaluate them. Note that we explicitly exclude + * on-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or + * canEvaluateInRight to prevent pushing these predicates on either side of the join. * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { + // Note: In order to ensure correctness, it's important to not change the relative ordering of + // any deterministic expression that follows a non-deterministic expression. To achieve this, + // we only consider pushing down those expressions that precede the first non-deterministic + // expression in the condition. + val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic) val (leftEvaluateCondition, rest) = - condition.partition(_.references subsetOf left.outputSet) + pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = - rest.partition(_.references subsetOf right.outputSet) + rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition) + (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -1340,7 +1230,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } // push down the join filter into sub query scanning if applicable - case f @ Join(left, right, joinType, joinCondition) => + case j @ Join(left, right, joinType, joinCondition) => val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) @@ -1370,7 +1260,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) Join(newLeft, newRight, LeftOuter, newJoinCond) - case FullOuter => f + case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") } @@ -1549,9 +1439,16 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { */ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(grouping, _, _) => + case a @ Aggregate(grouping, _, _) if grouping.nonEmpty => val newGrouping = grouping.filter(!_.foldable) - a.copy(groupingExpressions = newGrouping) + if (newGrouping.nonEmpty) { + a.copy(groupingExpressions = newGrouping) + } else { + // All grouping expressions are literals. We should not drop them all, because this can + // change the return semantics when the input of the Aggregate is empty (SPARK-17114). We + // instead replace this by single, easy to hash/sort, literal expression. + a.copy(groupingExpressions = Seq(Literal(0, IntegerType))) + } } } @@ -1567,97 +1464,6 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } -/** - * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can - * be evaluated. This is mainly used to provide compatibility with other databases. - * For example, we use this to support "nvl" by replacing it with "coalesce". - */ -object ReplaceExpressions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case e: RuntimeReplaceable => e.replaced - } -} - -/** - * Computes the current date and time to make sure we return the same result in a single query. - */ -object ComputeCurrentTime extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val dateExpr = CurrentDate() - val timeExpr = CurrentTimestamp() - val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) - val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) - - plan transformAllExpressions { - case CurrentDate() => currentDate - case CurrentTimestamp() => currentTime - } - } -} - -/** Replaces the expression of CurrentDatabase with the current database name. */ -case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - plan transformAllExpressions { - case CurrentDatabase() => - Literal.create(sessionCatalog.getCurrentDatabase, StringType) - } - } -} - -/** - * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a - * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed - * the deserializer in filter condition to save the extra serialization at last. - */ -object EmbedSerializerInFilter extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) - // SPARK-15632: Conceptually, filter operator should never introduce schema change. This - // optimization rule also relies on this assumption. However, Dataset typed filter operator - // does introduce schema changes in some cases. Thus, we only enable this optimization when - // - // 1. either input and output schemata are exactly the same, or - // 2. both input and output schemata are single-field schema and share the same type. - // - // The 2nd case is included because encoders for primitive types always have only a single - // field with hard-coded field name "value". - // TODO Cleans this up after fixing SPARK-15632. - if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) => - - val numObjects = condition.collect { - case a: Attribute if a == d.output.head => a - }.length - - if (numObjects > 1) { - // If the filter condition references the object more than one times, we should not embed - // deserializer in it as the deserialization will happen many times and slow down the - // execution. - // TODO: we can still embed it if we can make sure subexpression elimination works here. - s - } else { - val newCondition = condition transform { - case a: Attribute if a == d.output.head => d.deserializer - } - val filter = Filter(newCondition, d.child) - - // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`. - // We will remove it later in RemoveAliasOnlyProject rule. - val objAttrs = filter.output.zip(s.output).map { case (fout, sout) => - Alias(fout, fout.name)(exprId = sout.exprId) - } - Project(objAttrs, filter) - } - } - - def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = { - (lhs, rhs) match { - case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType - case _ => false - } - } -} - /** * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates * are supported: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala similarity index 93% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 8afd28dbba5c..d6a39ecf53b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.analysis +package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} @@ -119,14 +119,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - // Aggregation strategy can handle the query with single distinct - if (distinctAggGroups.size > 1) { + // Check if the aggregates contains functions that do not support partial aggregation. + val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial) + + // Aggregation strategy can handle queries with a single distinct group and partial aggregates. + if (distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && existsNonPartial)) { // Create the attributes for the grouping id and the group by clause. - val gid = - new AttributeReference("gid", IntegerType, false)(isGenerated = true) + val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)() + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -135,9 +137,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Expression): AggregateFunction = { - af.withNewChildren(af.children.map { - case afc => attrs(afc) - }).asInstanceOf[AggregateFunction] + af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction] } // Setup unique distinct aggregate children. @@ -265,5 +265,5 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.sql, e.dataType, true)() + e -> AttributeReference(e.sql, e.dataType, nullable = true)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala new file mode 100644 index 000000000000..7c667315870f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types._ + + +/** + * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can + * be evaluated. This is mainly used to provide compatibility with other databases. + * For example, we use this to support "nvl" by replacing it with "coalesce". + */ +object ReplaceExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case e: RuntimeReplaceable => e.replaced + } +} + + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} + + +/** Replaces the expression of CurrentDatabase with the current database name. */ +case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case CurrentDatabase() => + Literal.create(sessionCatalog.getCurrentDatabase, StringType) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala new file mode 100644 index 000000000000..ae4cd8e8709a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +/** + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ + @tailrec + def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + val (joinConditions, others) = conditions.partition( + e => !SubqueryExpression.hasCorrelatedSubquery(e)) + val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And)) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), join) + } else { + join + } + } else { + val left :: rest = input.toList + // find out the first join that have at least one join condition + val conditionalJoin = rest.find { plan => + val refs = left.outputSet ++ plan.outputSet + conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) + .exists(_.references.subsetOf(refs)) + } + // pick the next one if no condition left + val right = conditionalJoin.getOrElse(rest.head) + + val joinedRefs = left.outputSet ++ right.outputSet + val (joinConditions, others) = conditions.partition( + e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) + val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + createOrderedJoin(input, conditions) + } +} + + +/** + * Elimination of outer joins, if the predicates can restrict the result sets so that + * all null-supplying rows are eliminated + * + * - full outer -> inner if both sides have such predicates + * - left outer -> inner if the right side has such predicates + * - right outer -> inner if the left side has such predicates + * - full outer -> left outer if only the left side has such predicates + * - full outer -> right outer if only the right side has such predicates + * + * This rule should be executed before pushing down the Filter + */ +object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Returns whether the expression returns null or false when all inputs are nulls. + */ + private def canFilterOutNull(e: Expression): Boolean = { + if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false + val attributes = e.references.toSeq + val emptyRow = new GenericInternalRow(attributes.length) + val boundE = BindReferences.bindReference(e, attributes) + if (boundE.find(_.isInstanceOf[Unevaluable]).isDefined) return false + val v = boundE.eval(emptyRow) + v == null || v == false + } + + private def buildNewJoinType(filter: Filter, join: Join): JoinType = { + val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints + val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) + val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) + + val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + + join.joinType match { + case RightOuter if leftHasNonNullPredicate => Inner + case LeftOuter if rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate => LeftOuter + case FullOuter if rightHasNonNullPredicate => RightOuter + case o => o + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => + val newJoinType = buildNewJoinType(f, j) + if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala new file mode 100644 index 000000000000..8a25cee614c6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.StructType + +/* + * This file defines optimization rules related to object manipulation (for the Dataset API). + */ + + +/** + * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) + * representation of data item. For example back to back map operations. + */ +object EliminateSerialization extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case d @ DeserializeToObject(_, _, s: SerializeFromObject) + if d.outputObjectType == s.inputObjectType => + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. + // We will remove it later in RemoveAliasOnlyProject rule. + val objAttr = + Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) + case a @ AppendColumns(_, _, _, s: SerializeFromObject) + if a.deserializer.dataType == s.inputObjectType => + AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) + } +} + + +/** + * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a + * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed + * the deserializer in filter condition to save the extra serialization at last. + */ +object EmbedSerializerInFilter extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) + // SPARK-15632: Conceptually, filter operator should never introduce schema change. This + // optimization rule also relies on this assumption. However, Dataset typed filter operator + // does introduce schema changes in some cases. Thus, we only enable this optimization when + // + // 1. either input and output schemata are exactly the same, or + // 2. both input and output schemata are single-field schema and share the same type. + // + // The 2nd case is included because encoders for primitive types always have only a single + // field with hard-coded field name "value". + // TODO Cleans this up after fixing SPARK-15632. + if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) => + + val numObjects = condition.collect { + case a: Attribute if a == d.output.head => a + }.length + + if (numObjects > 1) { + // If the filter condition references the object more than one times, we should not embed + // deserializer in it as the deserialization will happen many times and slow down the + // execution. + // TODO: we can still embed it if we can make sure subexpression elimination works here. + s + } else { + val newCondition = condition transform { + case a: Attribute if a == d.output.head => d.deserializer + } + val filter = Filter(newCondition, d.child) + + // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`. + // We will remove it later in RemoveAliasOnlyProject rule. + val objAttrs = filter.output.zip(s.output).map { case (fout, sout) => + Alias(fout, fout.name)(exprId = sout.exprId) + } + Project(objAttrs, filter) + } + } + + def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = { + (lhs, rhs) match { + case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType + case _ => false + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c7420a1c5965..d1ce90770d3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -26,7 +26,8 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -90,10 +91,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Apply CTEs query.optional(ctx.ctes) { - val ctes = ctx.ctes.namedQuery.asScala.map { - case nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) + val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) } // Check for duplicate names. checkDuplicateKeys(ctes, ctx) @@ -132,7 +132,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Build the insert clauses. val inserts = ctx.multiInsertQueryBody.asScala.map { body => - assert(body.querySpecification.fromClause == null, + validate(body.querySpecification.fromClause == null, "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", body) @@ -399,7 +399,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * separated) relations here, these get converted into a single plan by condition-less inner join. */ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { - val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left)(Join(_, _, Inner, None)) + withJoinRelations(join, relation) + } ctx.lateralView.asScala.foldLeft(from)(withGenerate) } @@ -525,54 +529,49 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a joins between two or more logical plans. + * Create a single relation referenced in a FROM claused. This method is used when a part of the + * join condition is nested, for example: + * {{{ + * select * from t1 join (t2 cross join t3) on col1 = col2 + * }}} */ - override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { - /** Build a join between two plans. */ - def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { - val baseJoinType = ctx.joinType match { - case null => Inner - case jt if jt.FULL != null => FullOuter - case jt if jt.SEMI != null => LeftSemi - case jt if jt.ANTI != null => LeftAnti - case jt if jt.LEFT != null => LeftOuter - case jt if jt.RIGHT != null => RightOuter - case _ => Inner - } + override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { + withJoinRelations(plan(ctx.relationPrimary), ctx) + } - // Resolve the join type and join condition - val (joinType, condition) = Option(ctx.joinCriteria) match { - case Some(c) if c.USING != null => - val columns = c.identifier.asScala.map { column => - UnresolvedAttribute.quoted(column.getText) - } - (UsingJoin(baseJoinType, columns), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case None if ctx.NATURAL != null => - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - Join(left, right, joinType, condition) - } + /** + * Join one more [[LogicalPlan]]s to the current logical plan. + */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } - // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the - // first join clause is at the top. However fields of previously referenced tables can be used - // in following join clauses. The tree needs to be reversed in order to make this work. - var result = plan(ctx.left) - var current = ctx - while (current != null) { - current.right match { - case right: JoinRelationContext => - result = join(current, result, plan(right.left)) - current = right - case right => - result = join(current, result, plan(right)) - current = null + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if join.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, plan(join.right), joinType, condition) } } - result } /** @@ -591,7 +590,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // function takes X PERCENT as the input and the range of X is [0, 100], we need to // adjust the fraction. val eps = RandomSampler.roundingEpsilon - assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) @@ -652,44 +651,37 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { table.optionalMap(ctx.sample)(withSample) } + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + } + /** * Create an inline table (a virtual table in Hive parlance). */ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { // Get the backing expressions. - val expressions = ctx.expression.asScala.map { eCtx => - val e = expression(eCtx) - assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) - e - } - - // Validate and evaluate the rows. - val (structType, structConstructor) = expressions.head.dataType match { - case st: StructType => - (st, (e: Expression) => e) - case dt => - val st = CreateStruct(Seq(expressions.head)).dataType - (st, (e: Expression) => CreateStruct(Seq(e))) - } - val rows = expressions.map { - case expression => - val safe = Cast(structConstructor(expression), structType) - safe.eval().asInstanceOf[InternalRow] + val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case CreateStruct(children) => children // style 1 + case child => Seq(child) // style 2 + } } - // Construct attributes. - val baseAttributes = structType.toAttributes.map(_.withNullability(true)) - val attributes = if (ctx.identifierList != null) { - val aliases = visitIdentifierList(ctx.identifierList) - assert(aliases.size == baseAttributes.size, - "Number of aliases must match the number of fields in an inline table.", ctx) - baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + val aliases = if (ctx.identifierList != null) { + visitIdentifierList(ctx.identifierList) } else { - baseAttributes + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } - // Create plan and add an alias if a name has been defined. - LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + val table = UnresolvedInlineTable(aliases, rows) + table.optionalMap(ctx.identifier)(aliasPlan) } /** @@ -1022,6 +1014,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } + /** + * Create a current timestamp/date expression. These are different from regular function because + * they do not require the user to specify braces when calling them. + */ + override def visitTimeFunctionCall(ctx: TimeFunctionCallContext): Expression = withOrigin(ctx) { + ctx.name.getType match { + case SqlBaseParser.CURRENT_DATE => + CurrentDate() + case SqlBaseParser.CURRENT_TIMESTAMP => + CurrentTimestamp() + } + } + /** * Create a function database (optional) and name pair. */ @@ -1076,7 +1081,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // We currently only allow foldable integers. def value: Int = { val e = expression(ctx.expression) - assert(e.resolved && e.foldable && e.dataType == IntegerType, + validate(e.resolved && e.foldable && e.dataType == IntegerType, "Frame bound value must be a constant integer.", ctx) e.eval().asInstanceOf[Int] @@ -1267,10 +1272,17 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** Create a numeric literal expression. */ - private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { - val raw = ctx.getText + private def numericLiteral + (ctx: NumberContext, minValue: BigDecimal, maxValue: BigDecimal, typeName: String) + (converter: String => Any): Literal = withOrigin(ctx) { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) try { - Literal(f(raw.substring(0, raw.length - 1))) + val rawBigDecimal = BigDecimal(rawStrippedQualifier) + if (rawBigDecimal < minValue || rawBigDecimal > maxValue) { + throw new ParseException(s"Numeric literal ${rawStrippedQualifier} does not " + + s"fit in range [${minValue}, ${maxValue}] for type ${typeName}", ctx) + } + Literal(converter(rawStrippedQualifier)) } catch { case e: NumberFormatException => throw new ParseException(e.getMessage, ctx) @@ -1280,29 +1292,42 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { /** * Create a Byte Literal expression. */ - override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { - _.toByte + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = { + numericLiteral(ctx, Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte) } /** * Create a Short Literal expression. */ - override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { - _.toShort + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = { + numericLiteral(ctx, Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort) } /** * Create a Long Literal expression. */ - override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { - _.toLong + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = { + numericLiteral(ctx, Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) } /** * Create a Double Literal expression. */ - override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { - _.toDouble + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = { + numericLiteral(ctx, Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** + * Create a BigDecimal Literal expression. + */ + override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = { + val raw = ctx.getText.substring(0, ctx.getText.length - 2) + try { + Literal(BigDecimal(raw).underlying()) + } catch { + case e: AnalysisException => + throw new ParseException(e.message, ctx) + } } /** @@ -1329,7 +1354,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { val intervals = ctx.intervalField.asScala.map(visitIntervalField) - assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + validate(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) Literal(intervals.reduce(_.add(_))) } @@ -1356,7 +1381,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case (from, Some(t)) => throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) } - assert(interval != null, "No interval can be constructed", ctx) + validate(interval != null, "No interval can be constructed", ctx) interval } catch { // Handle Exceptions thrown by CalendarInterval diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index b04ce58e233a..9506de229b87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -77,8 +77,8 @@ object ParserUtils { Origin(Option(token.getLine), Option(token.getCharPositionInLine)) } - /** Assert if a condition holds. If it doesn't throw a parse exception. */ - def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + /** Validate the condition. If it doesn't throw a parse exception. */ + def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { if (!f) { throw new ParseException(message, ctx) } @@ -192,9 +192,7 @@ object ParserUtils { * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the * passed function. The original plan is returned when the context does not exist. */ - def optionalMap[C <: ParserRuleContext]( - ctx: C)( - f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { if (ctx != null) { f(ctx, plan) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index cf34f4b30d8d..a6f94aa062b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.catalyst.plans @@ -35,7 +53,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) && + constraint.deterministic) } /** @@ -199,6 +218,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) + case Some(seq: Traversable[_]) => Some(seq.map(recursiveTransform)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) @@ -233,6 +253,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT productIterator.flatMap { case e: Expression => e :: Nil case Some(e: Expression) => e :: Nil + case Some(seq: Traversable[_] ) => seqToExpressions(seq) case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 9d64f35efcc6..890865d17784 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -75,4 +76,16 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) override lazy val statistics = Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length) + + def toSQL(inlineTableName: String): String = { + require(data.nonEmpty) + val types = output.map(_.dataType) + val rows = data.map { row => + val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql } + cells.mkString("(", ", ", ")") + } + "VALUES " + rows.mkString(", ") + + " AS " + inlineTableName + + output.map(_.name).mkString("(", ", ", ")") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b31f5aa11c22..07e39b029894 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -127,7 +127,7 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar } } -private[sql] object SetOperation { +object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } @@ -365,7 +365,7 @@ case class InsertIntoTable( override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty - private[spark] lazy val expectedColumns = { + lazy val expectedColumns = { if (table.output.isEmpty) { None } else { @@ -422,17 +422,20 @@ case class Sort( /** Factory for constructing new `Range` nodes. */ object Range { - def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { + def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = { val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes new Range(start, end, step, numSlices, output) } + def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { + Range(start, end, step, Some(numSlices)) + } } case class Range( start: Long, end: Long, step: Long, - numSlices: Int, + numSlices: Option[Int], output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { @@ -449,6 +452,14 @@ case class Range( } } + def toSQL(): String = { + if (numSlices.isDefined) { + s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step, ${numSlices.get})" + } else { + s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step)" + } + } + override def newInstance(): Range = copy(output = output.map(_.newInstance())) override lazy val statistics: Statistics = { @@ -457,11 +468,7 @@ case class Range( } override def simpleString: String = { - if (step == 1) { - s"Range ($start, $end, splits=$numSlices)" - } else { - s"Range ($start, $end, step=$step, splits=$numSlices)" - } + s"Range ($start, $end, step=$step, splits=$numSlices)" } } @@ -509,7 +516,7 @@ case class Window( def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) } -private[sql] object Expand { +object Expand { /** * Extract attribute set according to the grouping id. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 51d78dd1233f..6bc140aa9aef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.catalyst.plans.physical @@ -229,13 +247,57 @@ case object SinglePartition extends Partitioning { override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. Moreover while evaluating expressions if they are given in different order + * than this partitioning then also it is considered equal. + */ +case class OrderlessHashPartitioning(expressions: Seq[Expression], + numPartitions: Int, numBuckets: Int) + extends Expression with Partitioning with Unevaluable { + + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + private def matchExpressions(otherExpression: Seq[Expression]): Boolean = { + expressions.length == otherExpression.length && expressions.forall(a => + otherExpression.exists(e => e.semanticEquals(a))) + } + + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case ClusteredDistribution(requiredClustering) => + matchExpressions(requiredClustering) + case _ => false + } + + private def anyOrderEquals(other: HashPartitioning) : Boolean = { + other.numBuckets == this.numBuckets && + other.numPartitions == this.numPartitions && + matchExpressions(other.expressions) + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case p: HashPartitioning => anyOrderEquals(p) + case _ => false + } + + override def guarantees(other: Partitioning): Boolean = other match { + case p: HashPartitioning => anyOrderEquals(p) + case _ => false + } + +} + /** * Represents a partitioning where rows are split up across partitions based on the hash * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be * in the same partition. */ -case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression with Partitioning with Unevaluable { +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int, + numBuckets: Int = 0) extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions override def nullable: Boolean = false @@ -249,12 +311,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) + case o: HashPartitioning => + this.numBuckets == o.numBuckets && this.semanticEquals(o) case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) + case o: HashPartitioning => + this.numBuckets == o.numBuckets && this.semanticEquals(o) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 072445af4f41..eeccba79e42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -613,7 +613,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case s: String => JString(s) case u: UUID => JString(u.toString) case dt: DataType => dt.jsonValue - case m: Metadata => m.jsonValue + // SPARK-17356: In usage of mllib, Metadata may store a huge vector of data, transforming + // it to JSON may trigger OutOfMemoryError. + case m: Metadata => Metadata.empty.jsonValue case s: StorageLevel => ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~ ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala index 6d35f140cf23..0c7205b3c665 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala @@ -23,7 +23,7 @@ package org.apache.spark.sql.catalyst.util * `Row` in order to work around a spurious IntelliJ compiler error. This cannot be an abstract * class because that leads to compilation errors under Scala 2.11. */ -private[spark] class AbstractScalaRowIterator[T] extends Iterator[T] { +class AbstractScalaRowIterator[T] extends Iterator[T] { override def hasNext: Boolean = throw new NotImplementedError override def next(): T = throw new NotImplementedError diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index df480a1d65bc..0b643a5b8426 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -852,8 +852,10 @@ object DateTimeUtils { /** * Lookup the offset for given millis seconds since 1970-01-01 00:00:00 in given timezone. + * TODO: Improve handling of normalization differences. + * TODO: Replace with JSR-310 or similar system - see SPARK-16788 */ - private def getOffsetFromLocalMillis(millisLocal: Long, tz: TimeZone): Long = { + private[sql] def getOffsetFromLocalMillis(millisLocal: Long, tz: TimeZone): Long = { var guess = tz.getRawOffset // the actual offset should be calculated based on milliseconds in UTC val offset = tz.getOffset(millisLocal - guess) @@ -875,11 +877,11 @@ object DateTimeUtils { val hh = seconds / 3600 val mm = seconds / 60 % 60 val ss = seconds % 60 - val nano = millisOfDay % 1000 * 1000000 - - // create a Timestamp to get the unix timestamp (in UTC) - val timestamp = new Timestamp(year - 1900, month - 1, day, hh, mm, ss, nano) - guess = (millisLocal - timestamp.getTime).toInt + val ms = millisOfDay % 1000 + val calendar = Calendar.getInstance(tz) + calendar.set(year, month - 1, day, hh, mm, ss) + calendar.set(Calendar.MILLISECOND, ms) + guess = (millisLocal - calendar.getTimeInMillis()).toInt } } guess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 65eae869d40d..00c92fc7ef95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.types @@ -131,7 +149,7 @@ protected[sql] abstract class AtomicType extends DataType { private[sql] val tag: TypeTag[InternalType] private[sql] val ordering: Ordering[InternalType] - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + @transient private[sql] lazy val classTag = ScalaReflectionLock.synchronized { val mirror = runtimeMirror(Utils.getSparkClassLoader) ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 520e34436162..82a03b0afc00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def catalogString: String = s"array<${elementType.catalogString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" override private[spark] def asNullable: ArrayType = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index cc8175c0a366..70859052872d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -242,10 +242,30 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (scale < _scale) { // Easier case: we just need to divide our scale down val diff = _scale - scale - val droppedDigits = longVal % POW_10(diff) - longVal /= POW_10(diff) - if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { - longVal += (if (longVal < 0) -1L else 1L) + val pow10diff = POW_10(diff) + // % and / always round to 0 + val droppedDigits = longVal % pow10diff + longVal /= pow10diff + roundMode match { + case ROUND_FLOOR => + if (droppedDigits < 0) { + longVal += -1L + } + case ROUND_CEILING => + if (droppedDigits > 0) { + longVal += 1L + } + case ROUND_HALF_UP => + if (math.abs(droppedDigits) * 2 >= pow10diff) { + longVal += (if (droppedDigits < 0) -1L else 1L) + } + case ROUND_HALF_EVEN => + val doubled = math.abs(droppedDigits) * 2 + if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { + longVal += (if (droppedDigits < 0) -1L else 1L) + } + case _ => + sys.error(s"Not supported rounding mode: $roundMode") } } else if (scale > _scale) { // We might be able to multiply longVal by a power of 10 and not overflow, but if not, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6500875f95e5..c57df0ae1776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.types @@ -31,7 +49,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number * of digits on right side of dot). * - * The precision can be up to 38, scale can also be up to 38 (less or equal to precision). + * The precision can be up to 127, scale can also be up to 127 (less or equal to precision). * * The default precision and scale is (10, 0). * @@ -46,7 +64,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException(s"DecimalType can only support precision up to 38") + throw new AnalysisException( + s"DecimalType can only support precision up to ${DecimalType.MAX_PRECISION}") } // default constructor for Java @@ -105,10 +124,10 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { object DecimalType extends AbstractDataType { import scala.math.min - val MAX_PRECISION = 38 - val MAX_SCALE = 38 - val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) - val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val MAX_PRECISION = 127 + val MAX_SCALE = 63 + val SYSTEM_DEFAULT: DecimalType = DecimalType(38, 18) + val USER_DEFAULT: DecimalType = DecimalType(38, 18) // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 454ea403bac2..178960929bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -64,6 +64,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def catalogString: String = s"map<${keyType.catalogString},${valueType.catalogString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" override private[spark] def asNullable: MapType = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 76e42d9afa4c..3aefb3cfc333 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -216,7 +216,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { assertError(operator(Seq('booleanField)), "requires at least 2 arguments") assertError(operator(Seq('intField, 'stringField)), "should all have the same type") - assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala new file mode 100644 index 000000000000..920c6ea50f4b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types.{LongType, NullType} + +/** + * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in + * end-to-end tests (in sql/core module) for verifying the correct error messages are shown + * in negative cases. + */ +class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { + + private def lit(v: Any): Literal = Literal(v) + + test("validate inputs are foldable") { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) + + // nondeterministic (rand) should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) + } + + // aggregate should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) + } + + // unresolved attribute should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) + } + } + + test("validate input dimensions") { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) + + // num alias != data dimension + intercept[AnalysisException] { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) + } + + // num alias == data dimension, but data themselves are inconsistent + intercept[AnalysisException] { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) + } + } + + test("do not fire the rule if not all expressions are resolved") { + val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) + assert(ResolveInlineTables(table) == table) + } + + test("convert") { + val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) + val converted = ResolveInlineTables.convert(table) + + assert(converted.output.map(_.dataType) == Seq(LongType)) + assert(converted.data.size == 2) + assert(converted.data(0).getLong(0) == 1L) + assert(converted.data(1).getLong(0) == 2L) + } + + test("nullability inference in convert") { + val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) + val converted1 = ResolveInlineTables.convert(table1) + assert(!converted1.schema.fields(0).nullable) + + val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) + val converted2 = ResolveInlineTables.convert(table2) + assert(converted2.schema.fields(0).nullable) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 971c99b67167..9560563a8ca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -283,6 +283,24 @@ class TypeCoercionSuite extends PlanTest { :: Cast(Literal(1), StringType) :: Cast(Literal("a"), StringType) :: Nil)) + + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateArray(Literal.create(null, DecimalType(5, 3)) + :: Literal(1) + :: Nil), + CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3)) + :: Literal(1).cast(DecimalType(13, 3)) + :: Nil)) + + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateArray(Literal.create(null, DecimalType(5, 3)) + :: Literal.create(null, DecimalType(22, 10)) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil), + CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Nil)) } test("CreateMap casts") { @@ -298,6 +316,17 @@ class TypeCoercionSuite extends PlanTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateMap(Literal.create(null, DecimalType(5, 3)) + :: Literal("a") + :: Literal.create(2.0, FloatType) + :: Literal("b") + :: Nil), + CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType) + :: Literal("a") + :: Literal.create(2.0, FloatType).cast(DoubleType) + :: Literal("b") + :: Nil)) // type coercion for map values ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) @@ -310,6 +339,17 @@ class TypeCoercionSuite extends PlanTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal.create(null, DecimalType(38, 0)) + :: Literal(2) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil), + CreateMap(Literal(1) + :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38)) + :: Literal(2) + :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Nil)) // type coercion for both map keys and values ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) @@ -344,6 +384,33 @@ class TypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal.create(null, DecimalType(10, 5)) + :: Literal(1) + :: Nil), + operator(Literal(1.0).cast(DoubleType) + :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) + :: Literal(1).cast(DoubleType) + :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + operator(Literal.create(null, DecimalType(15, 0)) + :: Literal.create(null, DecimalType(10, 5)) + :: Literal(1) + :: Nil), + operator(Literal.create(null, DecimalType(15, 0)).cast(DecimalType(20, 5)) + :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) + :: Literal(1).cast(DecimalType(20, 5)) + :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + operator(Literal.create(2L, LongType) + :: Literal(1) + :: Literal.create(null, DecimalType(10, 5)) + :: Nil), + operator(Literal.create(2L, LongType).cast(DecimalType(25, 5)) + :: Literal(1).cast(DecimalType(25, 5)) + :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(25, 5)) + :: Nil)) } } @@ -663,6 +730,13 @@ class TypeCoercionSuite extends PlanTest { // the right expression to Decimal. ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) } + + test("SPARK-17117 null type coercion in divide") { + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val nullLit = Literal.create(null, NullType) + ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) + ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 0c4d36336502..31e422e91aea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} import org.apache.spark.util.Utils @@ -439,14 +440,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("create function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.createFunction("does_not_exist", newFunc()) } } test("create function that already exists") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[FunctionAlreadyExistsException] { catalog.createFunction("db2", newFunc("func1")) } } @@ -460,14 +461,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("drop function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.dropFunction("does_not_exist", "something") } } test("drop function that does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.dropFunction("db2", "does_not_exist") } } @@ -477,14 +478,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(catalog.getFunction("db2", "func1") == CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, Seq.empty[FunctionResource])) - intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.getFunction("db2", "does_not_exist") } } test("get function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.getFunction("does_not_exist", "func1") } } @@ -494,15 +495,15 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val newName = "funcky" assert(catalog.getFunction("db2", "func1").className == funcClass) catalog.renameFunction("db2", "func1", newName) - intercept[AnalysisException] { catalog.getFunction("db2", "func1") } + intercept[NoSuchFunctionException] { catalog.getFunction("db2", "func1") } assert(catalog.getFunction("db2", newName).identifier.funcName == newName) assert(catalog.getFunction("db2", newName).className == funcClass) - intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") } + intercept[NoSuchFunctionException] { catalog.renameFunction("db2", "does_not_exist", "me") } } test("rename function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.renameFunction("does_not_exist", "func1", "func5") } } @@ -510,7 +511,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("rename function when new function already exists") { val catalog = newBasicCatalog() catalog.createFunction("db2", newFunc("func2", Some("db2"))) - intercept[AnalysisException] { + intercept[FunctionAlreadyExistsException] { catalog.renameFunction("db2", "func1", "func2") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 67ca0aadcc13..399b7067f4a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -201,16 +201,16 @@ class SessionCatalogSuite extends SparkFunSuite { val tempTable2 = Range(1, 20, 2, 10) catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) - assert(catalog.getTempTable("tbl1") == Option(tempTable1)) - assert(catalog.getTempTable("tbl2") == Option(tempTable2)) - assert(catalog.getTempTable("tbl3").isEmpty) + assert(catalog.getTempView("tbl1") == Option(tempTable1)) + assert(catalog.getTempView("tbl2") == Option(tempTable2)) + assert(catalog.getTempView("tbl3").isEmpty) // Temporary table already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } // Temporary table already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) - assert(catalog.getTempTable("tbl1") == Option(tempTable2)) + assert(catalog.getTempView("tbl1") == Option(tempTable2)) } test("drop table") { @@ -246,11 +246,11 @@ class SessionCatalogSuite extends SparkFunSuite { val tempTable = Range(1, 10, 2, 10) sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be dropped first sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) - assert(sessionCatalog.getTempTable("tbl1") == None) + assert(sessionCatalog.getTempView("tbl1") == None) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If temp table does not exist, the table in the current database should be dropped sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) @@ -259,7 +259,7 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false) - assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) } @@ -307,18 +307,18 @@ class SessionCatalogSuite extends SparkFunSuite { val tempTable = Range(1, 10, 2, 10) sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempTable("tbl1") == Option(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be renamed first sessionCatalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) - assert(sessionCatalog.getTempTable("tbl1").isEmpty) - assert(sessionCatalog.getTempTable("tbl3") == Option(tempTable)) + assert(sessionCatalog.getTempView("tbl1").isEmpty) + assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is specified, temp tables are never renamed sessionCatalog.renameTable( TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4", Some("db2"))) - assert(sessionCatalog.getTempTable("tbl3") == Option(tempTable)) - assert(sessionCatalog.getTempTable("tbl4").isEmpty) + assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) + assert(sessionCatalog.getTempView("tbl4").isEmpty) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) } @@ -423,46 +423,37 @@ class SessionCatalogSuite extends SparkFunSuite { assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) // If database is explicitly specified, do not check temporary tables val tempTable = Range(1, 10, 1, 10) - catalog.createTempView("tbl3", tempTable, overrideIfExists = false) assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) // If database is not explicitly specified, check the current database catalog.setCurrentDatabase("db2") assert(catalog.tableExists(TableIdentifier("tbl1"))) assert(catalog.tableExists(TableIdentifier("tbl2"))) - assert(catalog.tableExists(TableIdentifier("tbl3"))) - } - test("tableExists on temporary views") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - assert(!catalog.tableExists(TableIdentifier("view1"))) - assert(!catalog.tableExists(TableIdentifier("view1", Some("default")))) - catalog.createTempView("view1", tempTable, overrideIfExists = false) - assert(catalog.tableExists(TableIdentifier("view1"))) - assert(!catalog.tableExists(TableIdentifier("view1", Some("default")))) + catalog.createTempView("tbl3", tempTable, overrideIfExists = false) + // tableExists should not check temp view. + assert(!catalog.tableExists(TableIdentifier("tbl3"))) } - test("getTableMetadata on temporary views") { + test("getTempViewOrPermanentTableMetadata on temporary views") { val catalog = new SessionCatalog(newBasicCatalog()) val tempTable = Range(1, 10, 2, 10) - val m = intercept[AnalysisException] { - catalog.getTableMetadata(TableIdentifier("view1")) + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) }.getMessage - assert(m.contains("Table or view 'view1' not found in database 'default'")) - val m2 = intercept[AnalysisException] { - catalog.getTableMetadata(TableIdentifier("view1", Some("default"))) + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) }.getMessage - assert(m2.contains("Table or view 'view1' not found in database 'default'")) catalog.createTempView("view1", tempTable, overrideIfExists = false) - assert(catalog.getTableMetadata(TableIdentifier("view1")).identifier.table == "view1") - assert(catalog.getTableMetadata(TableIdentifier("view1")).schema(0).name == "id") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).identifier.table == "view1") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).schema(0).name == "id") - val m3 = intercept[AnalysisException] { - catalog.getTableMetadata(TableIdentifier("view1", Some("default"))) + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) }.getMessage - assert(m3.contains("Table or view 'view1' not found in database 'default'")) } test("list tables without pattern") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 2e37887fbc82..069c3b37ffa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -173,6 +173,17 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper // } } + test("SPARK-17617: % (Remainder) double % double on super big double") { + val leftDouble = Literal(-5083676433652386516D) + val rightDouble = Literal(10D) + checkEvaluation(Remainder(leftDouble, rightDouble), -6.0D) + + // Float has smaller precision + val leftFloat = Literal(-5083676433652386516F) + val rightFloat = Literal(10F) + checkEvaluation(Remainder(leftFloat, rightFloat), -2.0F) + } + test("Abs") { testNumericDataTypes { convert => val input = Literal(convert(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5ae0527a9c7a..5c35baacef2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -727,6 +727,16 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("cast struct with a timestamp field") { + val originalSchema = new StructType().add("tsField", TimestampType, nullable = false) + // nine out of ten times I'm casting a struct, it's to normalize its fields nullability + val targetSchema = new StructType().add("tsField", TimestampType, nullable = true) + + val inp = Literal.create(InternalRow(0L), originalSchema) + val expected = InternalRow(0L) + checkEvaluation(cast(inp, targetSchema), expected) + } + test("complex casting") { val complex = Literal.create( Row( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8ea8f6115084..45dcfcaf2313 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow +import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -58,8 +58,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateOrdering.generate(Add(Literal(123), Literal(1)).asc :: Nil) assert(CodegenMetrics.METRIC_COMPILATION_TIME.getCount() == startCount1 + 1) assert(CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount() == startCount2 + 1) - assert(CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount() > startCount1) - assert(CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount() > startCount1) + assert(CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount() > startCount3) + assert(CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount() > startCount4) } test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { @@ -265,4 +265,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.create("\\\\u001/Compilation error occurs", StringType) :: Nil) } + + test("SPARK-17160: field names are properly escaped by GetExternalRowField") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + GenerateUnsafeProjection.generate( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), IntegerType) :: Nil) + } + + test("SPARK-17160: field names are properly escaped by AssertTrue") { + GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 3c581ecdaf06..36185b8c637a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -181,6 +182,12 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) + // Type checking error + assert( + Least(Seq(Literal(1), Literal("1"))).checkInputDataTypes() == + TypeCheckFailure("The expressions should all have the same type, " + + "got LEAST(int, string).")) + DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } @@ -227,6 +234,12 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) + // Type checking error + assert( + Greatest(Seq(Literal(1), Literal("1"))).checkInputDataTypes() == + TypeCheckFailure("The expressions should all have the same type, " + + "got GREATEST(int, string).")) + DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index d6a9672d1f18..668543a28bd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -136,7 +136,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. val plan = generateProject( - GenerateUnsafeProjection.generate( + UnsafeProjection.create( Alias(expression, s"Optimized($expression)1")() :: Alias(expression, s"Optimized($expression)2")() :: Nil), expression) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala new file mode 100644 index 000000000000..b6263e77c141 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} + +class ObjectExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + test("MapObjects should make copies of unsafe-backed data") { + // test UnsafeRow-backed data + val structEncoder = ExpressionEncoder[Array[(java.lang.Integer, java.lang.Integer)]]() + val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) + val structExpected = new GenericArrayData( + Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) + checkEvalutionWithUnsafeProjection( + structEncoder.serializer.head, structExpected, structInputRow) + + // test UnsafeArray-backed data + val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]() + val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4)))) + val arrayExpected = new GenericArrayData( + Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) + checkEvalutionWithUnsafeProjection( + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + + // test UnsafeMap-backed data + val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]() + val mapInputRow = InternalRow.fromSeq(Seq(Array( + Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400)))) + val mapExpected = new GenericArrayData(Seq( + new ArrayBasedMapData( + new GenericArrayData(Array(1, 2)), + new GenericArrayData(Array(100, 200))), + new ArrayBasedMapData( + new GenericArrayData(Array(3, 4)), + new GenericArrayData(Array(300, 400))))) + checkEvalutionWithUnsafeProjection( + mapEncoder.serializer.head, mapExpected, mapInputRow) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala new file mode 100644 index 000000000000..7e45028653e3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.types.{IntegerType, StringType} + +class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("basic") { + val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil) + checkEvaluation(intUdf, 2) + + val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil) + checkEvaluation(stringUdf, "ax") + } + + test("better error message for NPE") { + val udf = ScalaUDF( + (s: String) => s.toLowerCase, + StringType, + Literal.create(null, StringType) :: Nil) + + val e1 = intercept[SparkException](udf.eval()) + assert(e1.getMessage.contains("Failed to execute user defined function")) + + val e2 = intercept[SparkException] { + checkEvalutionWithUnsafeProjection(udf, null) + } + assert(e2.getMessage.contains("Failed to execute user defined function")) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index b82cf8d1693e..d6c8fcf29184 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -108,4 +108,16 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva TimeWindow.invokePrivate(parseExpression(Rand(123))) } } + + test("SPARK-16837: TimeWindow.apply equivalent to TimeWindow constructor") { + val slideLength = "1 second" + for (windowLength <- Seq("10 second", "1 minute", "2 hours")) { + val applyValue = TimeWindow(Literal(10L), windowLength, slideLength, "0 seconds") + val constructed = new TimeWindow(Literal(10L), + Literal(windowLength), + Literal(slideLength), + Literal("0 seconds")) + assert(applyValue == constructed) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 4c26c184b7b5..aecf59aee6a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { - val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false) + val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) @@ -49,6 +49,14 @@ class AggregateOptimizeSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("do not remove all grouping expressions if they are all literals") { + val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) + + comparePlans(optimized, correctAnswer) + } + test("Remove aliased literals") { val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index b5664a5e699e..589607e3ad5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -346,5 +346,20 @@ class ColumnPruningSuite extends PlanTest { comparePlans(Optimize.execute(plan1.analyze), correctAnswer1) } + test("push project down into sample") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val x = testRelation.subquery('x) + + val query1 = Sample(0.0, 0.6, false, 11L, x)().select('a) + val optimized1 = Optimize.execute(query1.analyze) + val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))() + comparePlans(optimized1, expected1.analyze) + + val query2 = Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa) + val optimized2 = Optimize.execute(query2.analyze) + val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a))().select('a as 'aa) + comparePlans(optimized2, expected2.analyze) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 9cb49e74ad34..ea868d1a73a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -34,7 +34,6 @@ class FilterPushdownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Filter Pushdown", FixedPoint(10), - PushProjectThroughSample, CombineFilters, PushDownPredicate, BooleanSimplification, @@ -112,6 +111,12 @@ class FilterPushdownSuite extends PlanTest { assert(optimized == correctAnswer) } + test("SPARK-16994: filter should not be pushed through limit") { + val originalQuery = testRelation.limit(10).where('a === 1).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + test("can't push without rewrite") { val originalQuery = testRelation @@ -585,22 +590,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("push project and filter down into sample") { - val x = testRelation.subquery('x) - val originalQuery = - Sample(0.0, 0.6, false, 11L, x)().select('a) - - val originalQueryAnalyzed = - EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery)) - - val optimized = Optimize.execute(originalQueryAnalyzed) - - val correctAnswer = - Sample(0.0, 0.6, false, 11L, x.select('a))() - - comparePlans(optimized, correctAnswer.analyze) - } - test("aggregate: push down filter when filter on group by expression") { val originalQuery = testRelation .groupBy('a)('a, count('b) as 'c) @@ -998,4 +987,18 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } + + test("join condition pushdown: deterministic and non-deterministic") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + // Verify that all conditions preceding the first non-deterministic condition are pushed down + // by the optimizer and others are not. + val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 && + "x.a".attr === Rand(10) && "y.b".attr === 5)) + val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), + condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index 41754adef421..c168a55e40c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -192,4 +193,42 @@ class OuterJoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("joins: no outer join elimination if the filter is not NULL eliminated") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where(Coalesce("y.e".attr :: "x.a".attr :: Nil)) + + val optimized = Optimize.execute(originalQuery.analyze) + + val left = testRelation + val right = testRelation1 + val correctAnswer = + left.join(right, FullOuter, Option("a".attr === "d".attr)) + .where(Coalesce("e".attr :: "a".attr :: Nil)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: no outer join elimination if the filter's constraints are not NULL eliminated") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where(IsNotNull(Coalesce("y.e".attr :: "x.a".attr :: Nil))) + + val optimized = Optimize.execute(originalQuery.analyze) + + val left = testRelation + val right = testRelation1 + val correctAnswer = + left.join(right, FullOuter, Option("a".attr === "d".attr)) + .where(IsNotNull(Coalesce("e".attr :: "a".attr :: Nil))).analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala new file mode 100644 index 000000000000..0b973c3b659c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.types.{IntegerType, StringType} + +class RewriteDistinctAggregatesSuite extends PlanTest { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + val nullInt = Literal(null, IntegerType) + val nullString = Literal(null, StringType) + val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) + + private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { + case Aggregate(_, _, Aggregate(_, _, _: Expand)) => + case _ => fail(s"Plan is not rewritten:\n$rewrite") + } + + test("single distinct group") { + val input = testRelation + .groupBy('a)(countDistinct('e)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + countDistinct('e, 'c).as('agg1), + max('b).as('agg2)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with non-partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + countDistinct('e, 'c).as('agg1), + CollectSet('b).toAggregateExpression().as('agg2)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with partial aggregates") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with non-partial aggregates") { + val input = testRelation + .groupBy('a)( + countDistinct('b, 'c), + countDistinct('d), + CollectSet('b).toAggregateExpression()) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index e73592c7afa2..4aaae72fe91e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -375,20 +375,30 @@ class ExpressionParserSuite extends PlanTest { // Tiny Int Literal assertEqual("10Y", Literal(10.toByte)) - intercept("-1000Y") + intercept("-1000Y", s"does not fit in range [${Byte.MinValue}, ${Byte.MaxValue}]") // Small Int Literal assertEqual("10S", Literal(10.toShort)) - intercept("40000S") + intercept("40000S", s"does not fit in range [${Short.MinValue}, ${Short.MaxValue}]") // Long Int Literal assertEqual("10L", Literal(10L)) - intercept("78732472347982492793712334L") + intercept("78732472347982492793712334L", + s"does not fit in range [${Long.MinValue}, ${Long.MaxValue}]") // Double Literal assertEqual("10.0D", Literal(10.0D)) + intercept("-1.8E308D", s"does not fit in range") + intercept("1.8E308D", s"does not fit in range") // TODO we need to figure out if we should throw an exception here! assertEqual("1E309", Literal(Double.PositiveInfinity)) + + // BigDecimal Literal + assertEqual("90912830918230182310293801923652346786BD", + Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) + assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) + assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) + intercept("1.20E-38BD", "DecimalType can only support precision up to 38") } test("strings") { @@ -502,4 +512,22 @@ class ExpressionParserSuite extends PlanTest { assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") } + + test("current date/timestamp braceless expressions") { + assertEqual("current_date", CurrentDate()) + assertEqual("current_timestamp", CurrentTimestamp()) + } + + test("SPARK-17364, fully qualified column name which starts with number") { + assertEqual("123_", UnresolvedAttribute("123_")) + assertEqual("1a.123_", UnresolvedAttribute("1a.123_")) + // ".123" should not be treated as token of type DECIMAL_VALUE + assertEqual("a.123A", UnresolvedAttribute("a.123A")) + // ".123E3" should not be treated as token of type SCIENTIFIC_DECIMAL_VALUE + assertEqual("a.123E3_column", UnresolvedAttribute("a.123E3_column")) + // ".123D" should not be treated as token of type DOUBLE_LITERAL + assertEqual("a.123D_column", UnresolvedAttribute("a.123D_column")) + // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL + assertEqual("a.123BD_column", UnresolvedAttribute("a.123BD_column")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 456948d6455c..ac9c494a0b43 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -358,10 +357,54 @@ class PlanParserSuite extends PlanTest { test("left anti join", LeftAnti, testExistence) test("anti join", LeftAnti, testExistence) + // Test natural cross join + intercept("select * from a natural cross join b") + + // Test natural join with a condition + intercept("select * from a natural join b on a.id = b.id") + // Test multiple consecutive joins assertEqual( "select * from a join b join c right join d", table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + + // SPARK-17296 + assertEqual( + "select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id", + table("t1") + .join(table("t2"), Inner) + .join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id"))) + .join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id"))) + .select(star())) + + // Test multiple on clauses. + intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1") + + // Parenthesis + assertEqual( + "select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1", + table("t1") + .join(table("t2") + .join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1)) + .select(star())) + assertEqual( + "select * from t1 inner join (t2 inner join t3) on col3 = col2", + table("t1") + .join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2)) + .select(star())) + assertEqual( + "select * from t1 inner join (t2 inner join t3 on col3 = col2)", + table("t1") + .join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None) + .select(star())) + + // Implicit joins. + assertEqual( + "select * from t1, t3 join t2 on t1.col1 = t2.col2", + table("t1") + .join(table("t3")) + .join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2"))) + .select(star())) } test("sampled relations") { @@ -423,20 +466,21 @@ class PlanParserSuite extends PlanTest { assertEqual("table d.t", table("d", "t")) } + test("table valued function") { + assertEqual( + "select * from range(2)", + UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + } + test("inline table") { - assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( - Seq('col1.int), - Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual("values 1, 2, 3, 4", + UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))) + assertEqual( - "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", - LocalRelation.fromExternalRows( - Seq('a.int, 'b.string), - Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) - intercept("values (a, 'a'), (b, 'b')", - "All expressions in an inline table must be constants.") - intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", - "Number of aliases must match the number of fields in an inline table.") - intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + "values (1, 'a'), (2, 'b') as tbl(a, b)", + UnresolvedInlineTable( + Seq("a", "b"), + Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil).as("tbl")) } test("simple select query with !> and !<") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 8bbf87e62d41..4d3ad2179139 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -83,4 +83,17 @@ class TableIdentifierParserSuite extends SparkFunSuite { assert(TableIdentifier(nonReserved) === parseTableIdentifier(nonReserved)) } } + + test("SPARK-17364 table identifier - contains number") { + assert(parseTableIdentifier("123_") == TableIdentifier("123_")) + assert(parseTableIdentifier("1a.123_") == TableIdentifier("123_", Some("1a"))) + // ".123" should not be treated as token of type DECIMAL_VALUE + assert(parseTableIdentifier("a.123A") == TableIdentifier("123A", Some("a"))) + // ".123E3" should not be treated as token of type SCIENTIFIC_DECIMAL_VALUE + assert(parseTableIdentifier("a.123E3_LIST") == TableIdentifier("123E3_LIST", Some("a"))) + // ".123D" should not be treated as token of type DOUBLE_LITERAL + assert(parseTableIdentifier("a.123D_LIST") == TableIdentifier("123D_LIST", Some("a"))) + // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL + assert(parseTableIdentifier("a.123BD_LIST") == TableIdentifier("123BD_LIST", Some("a"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 5a76969235ac..8d6a49a8a37b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -352,4 +352,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { verifyConstraints(tr.analyze.constraints, ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c"))))) } + + test("not infer non-deterministic constraints") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + verifyConstraints(tr + .where('a.attr === Rand(0)) + .analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "a"))))) + + verifyConstraints(tr + .where('a.attr === InputFileName()) + .where('a.attr =!= 'c.attr) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") =!= resolveColumn(tr, "c"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c"))))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 059a5b7d07cd..4f516d006458 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -551,7 +551,8 @@ class DateTimeUtilsSuite extends SparkFunSuite { val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue) (-20000 to 20000).foreach { d => if (d != skipped) { - assert(millisToDays(daysToMillis(d)) === d) + assert(millisToDays(daysToMillis(d)) === d, + s"Round trip of ${d} did not work in tz ${tz}") } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 6b85f12521c2..569230accfd7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser class DataTypeSuite extends SparkFunSuite { @@ -342,4 +343,33 @@ class DataTypeSuite extends SparkFunSuite { StructField("a", StringType, nullable = false) :: StructField("b", StringType, nullable = false) :: Nil), expected = false) + + def checkCatalogString(dt: DataType): Unit = { + test(s"catalogString: $dt") { + val dt2 = CatalystSqlParser.parseDataType(dt.catalogString) + assert(dt === dt2) + } + } + def createStruct(n: Int): StructType = new StructType(Array.tabulate(n) { + i => StructField(s"col$i", IntegerType, nullable = true) + }) + + checkCatalogString(BooleanType) + checkCatalogString(ByteType) + checkCatalogString(ShortType) + checkCatalogString(IntegerType) + checkCatalogString(LongType) + checkCatalogString(FloatType) + checkCatalogString(DoubleType) + checkCatalogString(DecimalType(10, 5)) + checkCatalogString(BinaryType) + checkCatalogString(StringType) + checkCatalogString(DateType) + checkCatalogString(TimestampType) + checkCatalogString(createStruct(4)) + checkCatalogString(createStruct(40)) + checkCatalogString(ArrayType(IntegerType)) + checkCatalogString(ArrayType(createStruct(40))) + checkCatalogString(MapType(IntegerType, StringType)) + checkCatalogString(MapType(IntegerType, createStruct(40))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index e1675c95907a..4cf329ddee21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -22,6 +22,7 @@ import scala.language.postfixOps import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.Decimal._ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { /** Check that a Decimal has the given string representation, precision and scale */ @@ -193,4 +194,18 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } + + test("changePrecision() on compact decimal should respect rounding mode") { + Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => + Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => + Seq("", "-").foreach { sign => + val bd = BigDecimal(sign + n) + val unscaled = (bd * 10).toLongExact + val d = Decimal(unscaled, 8, 1) + assert(d.changePrecision(10, 0, mode)) + assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") + } + } + } + } } diff --git a/sql/core/benchmarks/WideSchemaBenchmark-results.txt b/sql/core/benchmarks/WideSchemaBenchmark-results.txt index ea6a6616c23d..0b9f791ac85e 100644 --- a/sql/core/benchmarks/WideSchemaBenchmark-results.txt +++ b/sql/core/benchmarks/WideSchemaBenchmark-results.txt @@ -1,93 +1,117 @@ -OpenJDK 64-Bit Server VM 1.8.0_66-internal-b17 on Linux 4.2.0-36-generic -Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + parsing large select: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -1 select expressions 3 / 5 0.0 2967064.0 1.0X -100 select expressions 11 / 12 0.0 11369518.0 0.3X -2500 select expressions 243 / 250 0.0 242561004.0 0.0X +1 select expressions 2 / 4 0.0 2050147.0 1.0X +100 select expressions 6 / 7 0.0 6123412.0 0.3X +2500 select expressions 135 / 141 0.0 134623148.0 0.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz -OpenJDK 64-Bit Server VM 1.8.0_66-internal-b17 on Linux 4.2.0-36-generic -Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz many column field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -1 cols x 100000 rows (read in-mem) 28 / 40 3.6 278.8 1.0X -1 cols x 100000 rows (exec in-mem) 28 / 42 3.5 284.0 1.0X -1 cols x 100000 rows (read parquet) 23 / 35 4.4 228.8 1.2X -1 cols x 100000 rows (write parquet) 163 / 182 0.6 1633.0 0.2X -100 cols x 1000 rows (read in-mem) 27 / 39 3.7 266.9 1.0X -100 cols x 1000 rows (exec in-mem) 48 / 79 2.1 481.7 0.6X -100 cols x 1000 rows (read parquet) 25 / 36 3.9 254.3 1.1X -100 cols x 1000 rows (write parquet) 182 / 196 0.5 1819.5 0.2X -2500 cols x 40 rows (read in-mem) 280 / 315 0.4 2797.1 0.1X -2500 cols x 40 rows (exec in-mem) 606 / 638 0.2 6064.3 0.0X -2500 cols x 40 rows (read parquet) 836 / 843 0.1 8356.4 0.0X -2500 cols x 40 rows (write parquet) 490 / 522 0.2 4900.6 0.1X +1 cols x 100000 rows (read in-mem) 16 / 18 6.3 158.6 1.0X +1 cols x 100000 rows (exec in-mem) 17 / 19 6.0 166.7 1.0X +1 cols x 100000 rows (read parquet) 24 / 26 4.3 235.1 0.7X +1 cols x 100000 rows (write parquet) 81 / 85 1.2 811.3 0.2X +100 cols x 1000 rows (read in-mem) 17 / 19 6.0 166.2 1.0X +100 cols x 1000 rows (exec in-mem) 25 / 27 4.0 249.2 0.6X +100 cols x 1000 rows (read parquet) 23 / 25 4.4 226.0 0.7X +100 cols x 1000 rows (write parquet) 83 / 87 1.2 831.0 0.2X +2500 cols x 40 rows (read in-mem) 132 / 137 0.8 1322.9 0.1X +2500 cols x 40 rows (exec in-mem) 326 / 330 0.3 3260.6 0.0X +2500 cols x 40 rows (read parquet) 831 / 839 0.1 8305.8 0.0X +2500 cols x 40 rows (write parquet) 237 / 245 0.4 2372.6 0.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz -OpenJDK 64-Bit Server VM 1.8.0_66-internal-b17 on Linux 4.2.0-36-generic -Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz wide shallowly nested struct field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -1 wide x 100000 rows (read in-mem) 22 / 35 4.6 216.0 1.0X -1 wide x 100000 rows (exec in-mem) 40 / 63 2.5 400.6 0.5X -1 wide x 100000 rows (read parquet) 93 / 134 1.1 933.9 0.2X -1 wide x 100000 rows (write parquet) 133 / 174 0.7 1334.3 0.2X -100 wide x 1000 rows (read in-mem) 22 / 44 4.5 223.3 1.0X -100 wide x 1000 rows (exec in-mem) 88 / 138 1.1 878.6 0.2X -100 wide x 1000 rows (read parquet) 117 / 186 0.9 1172.0 0.2X -100 wide x 1000 rows (write parquet) 144 / 174 0.7 1441.6 0.1X -2500 wide x 40 rows (read in-mem) 36 / 57 2.8 358.9 0.6X -2500 wide x 40 rows (exec in-mem) 1466 / 1507 0.1 14656.6 0.0X -2500 wide x 40 rows (read parquet) 690 / 802 0.1 6898.2 0.0X -2500 wide x 40 rows (write parquet) 197 / 207 0.5 1970.9 0.1X +1 wide x 100000 rows (read in-mem) 15 / 17 6.6 151.0 1.0X +1 wide x 100000 rows (exec in-mem) 20 / 22 5.1 196.6 0.8X +1 wide x 100000 rows (read parquet) 59 / 63 1.7 592.8 0.3X +1 wide x 100000 rows (write parquet) 81 / 87 1.2 814.6 0.2X +100 wide x 1000 rows (read in-mem) 21 / 25 4.8 208.7 0.7X +100 wide x 1000 rows (exec in-mem) 72 / 81 1.4 718.5 0.2X +100 wide x 1000 rows (read parquet) 75 / 85 1.3 752.6 0.2X +100 wide x 1000 rows (write parquet) 88 / 95 1.1 876.7 0.2X +2500 wide x 40 rows (read in-mem) 28 / 34 3.5 282.2 0.5X +2500 wide x 40 rows (exec in-mem) 1269 / 1284 0.1 12688.1 0.0X +2500 wide x 40 rows (read parquet) 549 / 578 0.2 5493.4 0.0X +2500 wide x 40 rows (write parquet) 96 / 104 1.0 959.1 0.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz -OpenJDK 64-Bit Server VM 1.8.0_66-internal-b17 on Linux 4.2.0-36-generic -Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz deeply nested struct field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -1 deep x 100000 rows (read in-mem) 22 / 35 4.5 223.9 1.0X -1 deep x 100000 rows (exec in-mem) 28 / 52 3.6 280.6 0.8X -1 deep x 100000 rows (read parquet) 41 / 65 2.4 410.5 0.5X -1 deep x 100000 rows (write parquet) 163 / 173 0.6 1634.5 0.1X -100 deep x 1000 rows (read in-mem) 43 / 63 2.3 425.9 0.5X -100 deep x 1000 rows (exec in-mem) 232 / 280 0.4 2321.7 0.1X -100 deep x 1000 rows (read parquet) 1989 / 2281 0.1 19886.6 0.0X -100 deep x 1000 rows (write parquet) 144 / 184 0.7 1442.6 0.2X -250 deep x 400 rows (read in-mem) 68 / 95 1.5 680.9 0.3X -250 deep x 400 rows (exec in-mem) 1310 / 1403 0.1 13096.4 0.0X -250 deep x 400 rows (read parquet) 41477 / 41847 0.0 414766.8 0.0X -250 deep x 400 rows (write parquet) 243 / 272 0.4 2433.1 0.1X +1 deep x 100000 rows (read in-mem) 14 / 16 7.0 143.8 1.0X +1 deep x 100000 rows (exec in-mem) 17 / 19 5.9 169.7 0.8X +1 deep x 100000 rows (read parquet) 33 / 35 3.1 327.0 0.4X +1 deep x 100000 rows (write parquet) 79 / 84 1.3 786.9 0.2X +100 deep x 1000 rows (read in-mem) 21 / 24 4.7 211.3 0.7X +100 deep x 1000 rows (exec in-mem) 221 / 235 0.5 2214.5 0.1X +100 deep x 1000 rows (read parquet) 1928 / 1952 0.1 19277.1 0.0X +100 deep x 1000 rows (write parquet) 91 / 96 1.1 909.5 0.2X +250 deep x 400 rows (read in-mem) 57 / 61 1.8 567.1 0.3X +250 deep x 400 rows (exec in-mem) 1329 / 1385 0.1 13291.8 0.0X +250 deep x 400 rows (read parquet) 36563 / 36750 0.0 365630.2 0.0X +250 deep x 400 rows (write parquet) 126 / 130 0.8 1262.0 0.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz -OpenJDK 64-Bit Server VM 1.8.0_66-internal-b17 on Linux 4.2.0-36-generic -Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz bushy struct field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -1 x 1 deep x 100000 rows (read in-mem) 23 / 36 4.4 229.8 1.0X -1 x 1 deep x 100000 rows (exec in-mem) 27 / 48 3.7 269.6 0.9X -1 x 1 deep x 100000 rows (read parquet) 25 / 33 4.0 247.5 0.9X -1 x 1 deep x 100000 rows (write parquet) 82 / 134 1.2 821.1 0.3X -128 x 8 deep x 1000 rows (read in-mem) 19 / 29 5.3 189.5 1.2X -128 x 8 deep x 1000 rows (exec in-mem) 144 / 165 0.7 1440.4 0.2X -128 x 8 deep x 1000 rows (read parquet) 117 / 159 0.9 1174.4 0.2X -128 x 8 deep x 1000 rows (write parquet) 135 / 162 0.7 1349.0 0.2X -1024 x 11 deep x 100 rows (read in-mem) 30 / 49 3.3 304.4 0.8X -1024 x 11 deep x 100 rows (exec in-mem) 1146 / 1183 0.1 11457.6 0.0X -1024 x 11 deep x 100 rows (read parquet) 712 / 758 0.1 7119.5 0.0X -1024 x 11 deep x 100 rows (write parquet) 104 / 143 1.0 1037.3 0.2X +1 x 1 deep x 100000 rows (read in-mem) 13 / 15 7.8 127.7 1.0X +1 x 1 deep x 100000 rows (exec in-mem) 15 / 17 6.6 151.5 0.8X +1 x 1 deep x 100000 rows (read parquet) 20 / 23 5.0 198.3 0.6X +1 x 1 deep x 100000 rows (write parquet) 77 / 82 1.3 770.4 0.2X +128 x 8 deep x 1000 rows (read in-mem) 12 / 14 8.2 122.5 1.0X +128 x 8 deep x 1000 rows (exec in-mem) 124 / 140 0.8 1241.2 0.1X +128 x 8 deep x 1000 rows (read parquet) 69 / 74 1.4 693.9 0.2X +128 x 8 deep x 1000 rows (write parquet) 78 / 83 1.3 777.7 0.2X +1024 x 11 deep x 100 rows (read in-mem) 25 / 29 4.1 246.1 0.5X +1024 x 11 deep x 100 rows (exec in-mem) 1197 / 1223 0.1 11974.6 0.0X +1024 x 11 deep x 100 rows (read parquet) 426 / 433 0.2 4263.7 0.0X +1024 x 11 deep x 100 rows (write parquet) 91 / 98 1.1 913.5 0.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz -OpenJDK 64-Bit Server VM 1.8.0_66-internal-b17 on Linux 4.2.0-36-generic -Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz wide array field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -1 wide x 100000 rows (read in-mem) 18 / 31 5.6 179.3 1.0X -1 wide x 100000 rows (exec in-mem) 31 / 47 3.2 310.2 0.6X -1 wide x 100000 rows (read parquet) 45 / 73 2.2 445.1 0.4X -1 wide x 100000 rows (write parquet) 109 / 140 0.9 1085.9 0.2X -100 wide x 1000 rows (read in-mem) 17 / 25 5.8 172.7 1.0X -100 wide x 1000 rows (exec in-mem) 18 / 22 5.4 184.6 1.0X -100 wide x 1000 rows (read parquet) 26 / 42 3.8 261.8 0.7X -100 wide x 1000 rows (write parquet) 150 / 164 0.7 1499.4 0.1X -2500 wide x 40 rows (read in-mem) 19 / 31 5.1 194.7 0.9X -2500 wide x 40 rows (exec in-mem) 19 / 24 5.3 188.5 1.0X -2500 wide x 40 rows (read parquet) 33 / 47 3.0 334.4 0.5X -2500 wide x 40 rows (write parquet) 153 / 164 0.7 1528.2 0.1X +1 wide x 100000 rows (read in-mem) 14 / 16 7.0 143.2 1.0X +1 wide x 100000 rows (exec in-mem) 17 / 19 5.9 170.9 0.8X +1 wide x 100000 rows (read parquet) 43 / 46 2.3 434.1 0.3X +1 wide x 100000 rows (write parquet) 78 / 83 1.3 777.6 0.2X +100 wide x 1000 rows (read in-mem) 11 / 13 9.0 111.5 1.3X +100 wide x 1000 rows (exec in-mem) 13 / 15 7.8 128.3 1.1X +100 wide x 1000 rows (read parquet) 24 / 27 4.1 245.0 0.6X +100 wide x 1000 rows (write parquet) 74 / 80 1.4 740.5 0.2X +2500 wide x 40 rows (read in-mem) 11 / 13 9.1 109.5 1.3X +2500 wide x 40 rows (exec in-mem) 13 / 15 7.7 129.4 1.1X +2500 wide x 40 rows (read parquet) 24 / 26 4.1 241.3 0.6X +2500 wide x 40 rows (write parquet) 75 / 81 1.3 751.8 0.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + +wide map field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 wide x 100000 rows (read in-mem) 16 / 18 6.2 162.6 1.0X +1 wide x 100000 rows (exec in-mem) 21 / 23 4.8 208.2 0.8X +1 wide x 100000 rows (read parquet) 54 / 59 1.8 543.6 0.3X +1 wide x 100000 rows (write parquet) 80 / 86 1.2 804.5 0.2X +100 wide x 1000 rows (read in-mem) 11 / 13 8.7 114.5 1.4X +100 wide x 1000 rows (exec in-mem) 14 / 16 7.0 143.5 1.1X +100 wide x 1000 rows (read parquet) 30 / 32 3.3 300.4 0.5X +100 wide x 1000 rows (write parquet) 75 / 80 1.3 749.9 0.2X +2500 wide x 40 rows (read in-mem) 13 / 15 7.8 128.1 1.3X +2500 wide x 40 rows (exec in-mem) 15 / 18 6.5 153.6 1.1X +2500 wide x 40 rows (read parquet) 30 / 33 3.3 304.4 0.5X +2500 wide x 40 rows (write parquet) 77 / 83 1.3 768.5 0.2X diff --git a/sql/core/build.gradle b/sql/core/build.gradle new file mode 100644 index 000000000000..c6b3f5038db8 --- /dev/null +++ b/sql/core/build.gradle @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project SQL' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-catalyst_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-sketch_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'com.univocity', name: 'univocity-parsers', version: '2.1.1' + compile group: 'org.apache.parquet', name: 'parquet-column', version: parquetVersion + compile group: 'org.apache.parquet', name: 'parquet-hadoop', version: parquetVersion + compile group: 'org.eclipse.jetty', name: 'jetty-servlet', version: jettyVersion + compile group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: fasterXmlVersion + + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile project(path: subprojectBase + 'snappy-spark-catalyst_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile group: 'com.h2database', name: 'h2', version: '1.4.183' + testCompile group: 'mysql', name: 'mysql-connector-java', version: '5.1.38' + testCompile group: 'org.postgresql', name: 'postgresql', version: '9.4.1207.jre7' + testCompile group: 'org.apache.parquet', name: 'parquet-avro', version: parquetVersion + testCompile group: 'org.apache.xbean', name: 'xbean-asm5-shaded', version: '4.4' +} + +// fix scala+java test ordering +sourceSets.test.scala.srcDirs 'src/test/java', 'src/test/gen-java' +sourceSets.test.java.srcDirs = [] diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 02a18b33b087..347f8a1bc82b 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 0d624d17f4cd..b903aeeb5c9e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -31,6 +31,8 @@ import java.util.Map; import java.util.Set; +import scala.Option; + import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; @@ -59,8 +61,12 @@ import org.apache.parquet.hadoop.util.ConfigurationUtil; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Types; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType$; +import org.apache.spark.util.AccumulatorV2; +import org.apache.spark.util.LongAccumulator; /** * Base class for custom RecordReaders for Parquet that directly materialize to `T`. @@ -144,6 +150,18 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont for (BlockMetaData block : blocks) { this.totalRowCount += block.getRowCount(); } + + // For test purpose. + // If the predefined accumulator exists, the row group number to read will be updated + // to the accumulator. So we can check if the row groups are filtered or not in test case. + TaskContext taskContext = TaskContext$.MODULE$.get(); + if (taskContext != null) { + Option> accu = (Option>) taskContext.taskMetrics() + .lookForAccumulatorByName("numRowGroups"); + if (accu.isDefined()) { + ((LongAccumulator)accu.get()).add((long)blocks.size()); + } + } } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 6c47dc09a863..3141edd06879 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -221,15 +221,21 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, if (column.dataType() == DataTypes.IntegerType || DecimalType.is32BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } } } else if (column.dataType() == DataTypes.ByteType) { for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } } } else if (column.dataType() == DataTypes.ShortType) { for (int i = rowId; i < rowId + num; ++i) { - column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } } } else { throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); @@ -240,7 +246,9 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } } } else { throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); @@ -249,21 +257,27 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, case FLOAT: for (int i = rowId; i < rowId + num; ++i) { - column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + } } break; case DOUBLE: for (int i = rowId; i < rowId + num; ++i) { - column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + } } break; case INT96: if (column.dataType() == DataTypes.TimestampType) { for (int i = rowId; i < rowId + num; ++i) { // TODO: Convert dictionary of Binaries to dictionary of Longs - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + } } } else { throw new UnsupportedOperationException(); @@ -275,26 +289,34 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, // and reuse it across batches. This should mean adding a ByteArray would just update // the length and offset. for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } } break; case FIXED_LEN_BYTE_ARRAY: // DecimalType written in the legacy mode if (DecimalType.is32BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putInt(i, (int) ParquetRowConverter.binaryToUnscaledLong(v)); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) ParquetRowConverter.binaryToUnscaledLong(v)); + } } } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putLong(i, ParquetRowConverter.binaryToUnscaledLong(v)); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, ParquetRowConverter.binaryToUnscaledLong(v)); + } } } else if (DecimalType.isByteArrayDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } } } else { throw new UnsupportedOperationException(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index bbbb796aca0d..9b0ed802fc9d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -282,16 +282,30 @@ public void reserve(int requiredCapacity) { if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { - reserveInternal(newCapacity); + try { + reserveInternal(newCapacity); + } catch (OutOfMemoryError outOfMemoryError) { + throwUnsupportedException(requiredCapacity, outOfMemoryError); + } } else { - throw new RuntimeException("Cannot reserve more than " + newCapacity + - " bytes in the vectorized reader (requested = " + requiredCapacity + " bytes). As a " + - "workaround, you can disable the vectorized reader by setting " - + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " to false."); + throwUnsupportedException(requiredCapacity, null); } } } + private void throwUnsupportedException(int requiredCapacity, Throwable cause) { + String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + + "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + + " to false."; + + if (cause != null) { + throw new RuntimeException(message, cause); + } else { + throw new RuntimeException(message); + } + } + /** * Ensures that there is enough storage to store capcity elements. That is, the put() APIs * must work for all rowIds < capcity. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 2fa476b9cfb7..900d7c431e72 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -86,8 +86,9 @@ public static void populate(ColumnVector col, InternalRow row, int fieldIdx) { col.getChildColumn(0).putInts(0, capacity, c.months); col.getChildColumn(1).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType) { - Date date = (Date)row.get(fieldIdx, t); - col.putInts(0, capacity, DateTimeUtils.fromJavaDate(date)); + col.putInts(0, capacity, row.getInt(fieldIdx)); + } else if (t instanceof TimestampType) { + col.putLongs(0, capacity, row.getLong(fieldIdx)); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index f3afa8f938f8..62abc2a821a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -137,6 +137,10 @@ public InternalRow copy() { DataType dt = columns[i].dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); } else if (dt instanceof IntegerType) { row.setInt(i, getInt(i)); } else if (dt instanceof LongType) { @@ -154,6 +158,8 @@ public InternalRow copy() { row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); } else if (dt instanceof DateType) { row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); } else { throw new RuntimeException("Not implemented. " + dt); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8c2885d7737..fe3da25a4c4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -269,17 +269,24 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all * character using backslash quoting mechanism
  • *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
  • - *
      - *
    • - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When - * a schema is set by user, it sets `null` for extra fields.
    • - *
    • - `DROPMALFORMED` : ignores the whole corrupted records.
    • - *
    • - `FAILFAST` : throws an exception when it meets corrupted records.
    • - *
    + * during parsing. + *
      + *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts + * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When + * a schema is set by user, it sets `null` for extra fields.
    • + *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • + *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • + *
    + * *
  • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • + *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. + * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to + * date type.
  • + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + * indicates a timestamp format. Custom date formats follow the formats at + * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • * * @since 2.0.0 */ @@ -370,16 +377,20 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * from values being read should be skipped. *
  • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing * whitespaces from values being read should be skipped.
  • - *
  • `nullValue` (default empty string): sets the string representation of a null value.
  • + *
  • `nullValue` (default empty string): sets the string representation of a null value. Since + * 2.0.1, this applies to all supported types including the string type.
  • *
  • `nanValue` (default `NaN`): sets the string representation of a non-number" value.
  • *
  • `positiveInf` (default `Inf`): sets the string representation of a positive infinity * value.
  • *
  • `negativeInf` (default `-Inf`): sets the string representation of a negative infinity * value.
  • - *
  • `dateFormat` (default `null`): sets the string that indicates a date format. Custom date - * formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type - * and timestamp type. By default, it is `null` which means trying to parse times and date by - * `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.
  • + *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. + * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to + * date type.
  • + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + * indicates a timestamp format. Custom date formats follow the formats at + * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + * `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()` or ISO 8601 format. *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
  • *
  • `maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed @@ -387,13 +398,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows * Spark will log for each partition. Malformed records beyond this number will be ignored.
  • *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
  • - *
      - *
    • - `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When - * a schema is set by user, it sets `null` for extra fields.
    • - *
    • - `DROPMALFORMED` : ignores the whole corrupted records.
    • - *
    • - `FAILFAST` : throws an exception when it meets corrupted records.
    • - *
    + * during parsing. + *
      + *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When + * a schema is set by user, it sets `null` for extra fields.
    • + *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • + *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • + *
    + * * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 12b304623d30..a4c4a5defa1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -449,9 +449,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following JSON-specific option(s) for writing JSON files: + *
      *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • + *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. + * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to + * date type.
    • + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + * indicates a timestamp format. Custom date formats follow the formats at + * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + *
    * * @since 1.4.0 */ @@ -467,10 +475,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following Parquet-specific option(s) for writing Parquet files: + *
      *
    • `compression` (default is the value specified in `spark.sql.parquet.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive * shorten names(none, `snappy`, `gzip`, and `lzo`). This will override * `spark.sql.parquet.compression.codec`.
    • + *
    * * @since 1.4.0 */ @@ -486,9 +496,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following ORC-specific option(s) for writing ORC files: + *
      *
    • `compression` (default `snappy`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names(`none`, `snappy`, `zlib`, and `lzo`). * This will override `orc.compress`.
    • + *
    * * @since 1.5.0 * @note Currently, this method can only be used after enabling Hive support @@ -510,9 +522,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following option(s) for writing text files: + *
      *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • + *
    * * @since 1.6.0 */ @@ -528,6 +542,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * }}} * * You can set the following CSV-specific option(s) for writing CSV files: + *
      *
    • `sep` (default `,`): sets the single character as a separator for each * field and value.
    • *
    • `quote` (default `"`): sets the single character used for escaping quoted values where @@ -544,6 +559,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • + *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. + * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to + * date type.
    • + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + * indicates a timestamp format. Custom date formats follow the formats at + * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + *
    * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 067cbec4bf61..0b236a0c7466 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -30,7 +30,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ -import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ @@ -1500,8 +1500,13 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { + require(fraction >= 0, + s"Fraction must be nonnegative, but got ${fraction}") + + withTypedPlan { + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + } } /** @@ -1529,6 +1534,11 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + require(weights.forall(_ >= 0), + s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") + require(weights.sum > 0, + s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}") + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the @@ -2509,8 +2519,12 @@ class Dataset[T] private[sql]( } private[sql] def collectToPython(): Int = { + EvaluatePython.registerPicklers() withNewExecutionId { - PythonRDD.collectAndServe(javaToPython.rdd) + val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) + val iter = new SerDeUtil.AutoBatchedPickler( + queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-DataFrame") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a6867a67eead..8eec42aab4fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -21,10 +21,11 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.ReduceAggregator /** * :: Experimental :: @@ -177,10 +178,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - - implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc) - flatMapGroups(func) + val vEncoder = encoderFor[V] + val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn + agg(aggregator) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 1aa5767038d5..6148ddfe05ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -128,7 +128,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * (Scala-specific) Compute aggregates by specifying a map from column name to + * (Scala-specific) Compute aggregates by specifying the column names and * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. @@ -143,7 +143,9 @@ class RelationalGroupedDataset protected[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - agg((aggExpr +: aggExprs).toMap) + toDF((aggExpr +: aggExprs).map { case (colName, expr) => + strToExpr(expr)(df(colName).expr) + }) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 946d8cbc6bf4..c88206c81a04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -822,16 +822,19 @@ object SparkSession { // No active nor global default session. Create a new one. val sparkContext = userSuppliedContext.getOrElse { // set app name if not given - if (!options.contains("spark.app.name")) { - options += "spark.app.name" -> java.util.UUID.randomUUID().toString - } - + val randomAppName = java.util.UUID.randomUUID().toString val sparkConf = new SparkConf() options.foreach { case (k, v) => sparkConf.set(k, v) } + if (!sparkConf.contains("spark.app.name")) { + sparkConf.setAppName(randomAppName) + } val sc = SparkContext.getOrCreate(sparkConf) // maybe this is an existing SparkContext, update its SparkConf which maybe used // by SparkSession options.foreach { case (k, v) => sc.conf.set(k, v) } + if (!sc.conf.contains("spark.app.name")) { + sc.conf.setAppName(randomAppName) + } sc } session = new SparkSession(sparkContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala index 5d93419f357e..5e0263ec5b4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala @@ -205,6 +205,12 @@ class SQLBuilder private ( case p: ScriptTransformation => scriptTransformationToSQL(p) + case p: LocalRelation => + p.toSQL(newSubqueryName()) + + case p: Range => + p.toSQL() + case OneRowRelation => "" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index de2503a87ab7..83b7c779ab81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ -private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) +case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) /** * Provides support in a SQLContext for caching query results and automatically using these cached @@ -41,7 +41,7 @@ private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMe * * Internal to Spark SQL. */ -private[sql] class CacheManager extends Logging { +class CacheManager extends Logging { @transient private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @@ -68,13 +68,13 @@ private[sql] class CacheManager extends Logging { } /** Clears all cached tables. */ - private[sql] def clearCache(): Unit = writeLock { + def clearCache(): Unit = writeLock { cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } /** Checks if the cache is empty. */ - private[sql] def isEmpty: Boolean = readLock { + def isEmpty: Boolean = readLock { cachedData.isEmpty } @@ -83,7 +83,7 @@ private[sql] class CacheManager extends Logging { * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because * recomputing the in-memory columnar representation of the underlying table is expensive. */ - private[sql] def cacheQuery( + def cacheQuery( query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { @@ -108,7 +108,7 @@ private[sql] class CacheManager extends Logging { * Tries to remove the data for the given [[Dataset]] from the cache. * No operation, if it's already uncached. */ - private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { + def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) val found = dataIndex >= 0 @@ -120,17 +120,17 @@ private[sql] class CacheManager extends Logging { } /** Optionally returns cached data for the given [[Dataset]] */ - private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { + def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ - private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { + def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } /** Replaces segments of the given logical plan with cached versions where possible. */ - private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { + def useCachedData(plan: LogicalPlan): LogicalPlan = { plan transformDown { case currentFragment => lookupCachedData(currentFragment) @@ -143,7 +143,7 @@ private[sql] class CacheManager extends Logging { * Invalidates the cache of any data that contains `plan`. Note that it is possible that this * function will over invalidate. */ - private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock { + def invalidateCache(plan: LogicalPlan): Unit = writeLock { cachedData.foreach { case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => data.cachedRepresentation.recache() @@ -155,7 +155,7 @@ private[sql] class CacheManager extends Logging { * Invalidates the cache of any data that contains `resourcePath` in one or more * `HadoopFsRelation` node(s) as part of its logical plan. */ - private[sql] def invalidateCachedPath( + def invalidateCachedPath( sparkSession: SparkSession, resourcePath: String): Unit = writeLock { val (fs, qualifiedPath) = { val path = new Path(resourcePath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 09203e69983d..ba30bed0b450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -75,7 +75,7 @@ object RDDConversions { } /** Logical plan node for scanning data from an RDD. */ -private[sql] case class LogicalRDD( +case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow])(session: SparkSession) extends LogicalPlan with MultiInstanceRelation { @@ -106,12 +106,12 @@ private[sql] case class LogicalRDD( } /** Physical plan node for scanning data from an RDD. */ -private[sql] case class RDDScanExec( +case class RDDScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], override val nodeName: String) extends LeafExecNode { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { @@ -130,7 +130,7 @@ private[sql] case class RDDScanExec( } } -private[sql] trait DataSourceScanExec extends LeafExecNode { +trait DataSourceScanExec extends LeafExecNode { val rdd: RDD[InternalRow] val relation: BaseRelation val metastoreTableIdentifier: Option[TableIdentifier] @@ -147,7 +147,7 @@ private[sql] trait DataSourceScanExec extends LeafExecNode { } /** Physical plan node for scanning data from a relation. */ -private[sql] case class RowDataSourceScanExec( +case class RowDataSourceScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], @transient relation: BaseRelation, @@ -156,7 +156,7 @@ private[sql] case class RowDataSourceScanExec( override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with CodegenSupport { - private[sql] override lazy val metrics = + override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) val outputUnsafeRows = relation match { @@ -222,7 +222,7 @@ private[sql] case class RowDataSourceScanExec( } /** Physical plan node for scanning data from a batched relation. */ -private[sql] case class BatchedDataSourceScanExec( +case class BatchedDataSourceScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], @transient relation: BaseRelation, @@ -231,7 +231,7 @@ private[sql] case class BatchedDataSourceScanExec( override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with CodegenSupport { - private[sql] override lazy val metrics = + override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) @@ -337,7 +337,7 @@ private[sql] case class BatchedDataSourceScanExec( } } -private[sql] object DataSourceScanExec { +object DataSourceScanExec { // Metadata keys val INPUT_PATHS = "InputPaths" val PUSHED_FILTERS = "PushedFilters" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 4c046f7bdca4..d5603b3b0091 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -39,7 +39,7 @@ case class ExpandExec( child: SparkPlan) extends UnaryExecNode with CodegenSupport { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) // The GroupExpressions can output data with arbitrary partitioning, so set it diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala index 7a2a9eed5807..a299fed7fd14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala @@ -22,7 +22,7 @@ package org.apache.spark.sql.execution * the list of paths that it returns will be returned to a user who calls `inputPaths` on any * DataFrame that queries this relation. */ -private[sql] trait FileRelation { +trait FileRelation { /** Returns the list of files that will be read when scanning this relation. */ def inputFiles: Array[String] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 8b62c5507c0c..39189a2b0c72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -55,7 +55,7 @@ case class GenerateExec( child: SparkPlan) extends UnaryExecNode { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def producedAttributes: AttributeSet = AttributeSet(output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index df2f238d8c2e..9f53a99346ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. */ -private[sql] case class LocalTableScanExec( +case class LocalTableScanExec( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafExecNode { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) private val unsafeRows: Array[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala index 7462dbc4eba3..717ff93eab5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow * iterator to consume the next row, whereas RowIterator combines these calls into a single * [[advanceNext()]] method. */ -private[sql] abstract class RowIterator { +abstract class RowIterator { /** * Advance this iterator by a single row. Returns `false` if this iterator has no more rows * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 6cb1a44a2044..ec07aab359ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} -private[sql] object SQLExecution { +object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 66a16ac576b3..cde3ed48ffea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,11 +22,9 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.RadixSort; /** * Performs (external) sorting. @@ -52,7 +50,7 @@ case class SortExec( private val enableRadixSort = sqlContext.conf.enableRadixSort - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"), "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 045ccc7bd6ea..79cb40948b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -72,24 +72,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Return all metadata that describes more details of this SparkPlan. */ - private[sql] def metadata: Map[String, String] = Map.empty + def metadata: Map[String, String] = Map.empty /** * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def metrics: Map[String, SQLMetric] = Map.empty + def metrics: Map[String, SQLMetric] = Map.empty /** * Reset all the metrics. */ - private[sql] def resetMetrics(): Unit = { + def resetMetrics(): Unit = { metrics.valuesIterator.foreach(_.reset()) } /** * Return a LongSQLMetric according to the name. */ - private[sql] def longMetric(name: String): SQLMetric = metrics(name) + def longMetric(name: String): SQLMetric = metrics(name) // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ @@ -395,7 +395,7 @@ object SparkPlan { ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) } -private[sql] trait LeafExecNode extends SparkPlan { +trait LeafExecNode extends SparkPlan { override def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet } @@ -407,7 +407,7 @@ object UnaryExecNode { } } -private[sql] trait UnaryExecNode extends SparkPlan { +trait UnaryExecNode extends SparkPlan { def child: SparkPlan override def children: Seq[SparkPlan] = child :: Nil @@ -415,7 +415,7 @@ private[sql] trait UnaryExecNode extends SparkPlan { override def outputPartitioning: Partitioning = child.outputPartitioning } -private[sql] trait BinaryExecNode extends SparkPlan { +trait BinaryExecNode extends SparkPlan { def left: SparkPlan def right: SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index f84070a0c4bc..7aa93126fdab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -47,7 +47,7 @@ class SparkPlanInfo( } } -private[sql] object SparkPlanInfo { +private[execution] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { val children = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3573a86d4e83..3072a6d79eac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution import scala.collection.JavaConverters._ -import scala.util.Try import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode @@ -405,6 +404,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } + /** + * Create a [[AlterTableRecoverPartitionsCommand]] command. + * + * For example: + * {{{ + * MSCK REPAIR TABLE tablename + * }}} + */ + override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) { + AlterTableRecoverPartitionsCommand( + visitTableIdentifier(ctx.tableIdentifier), + "MSCK REPAIR TABLE") + } + /** * Convert a table property list into a key-value map. * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. @@ -763,6 +776,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.EXISTS != null) } + /** + * Create an [[AlterTableRecoverPartitionsCommand]] command + * + * For example: + * {{{ + * ALTER TABLE table RECOVER PARTITIONS; + * }}} + */ + override def visitRecoverPartitions( + ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) { + AlterTableRecoverPartitionsCommand(visitTableIdentifier(ctx.tableIdentifier)) + } + /** * Create an [[AlterTableSetLocationCommand]] command * @@ -992,12 +1018,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { selectQuery match { case Some(q) => - // Just use whatever is projected in the select statement as our schema - if (schema.nonEmpty) { - operationNotAllowed( - "Schema may not be specified in a Create Table As Select (CTAS) statement", - ctx) - } // Hive does not allow to use a CTAS statement to create a partitioned table. if (tableDesc.partitionColumnNames.nonEmpty) { val errorMessage = "A Create Table As Select (CTAS) statement is not allowed to " + @@ -1007,6 +1027,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { "CTAS statement." operationNotAllowed(errorMessage, ctx) } + // Just use whatever is projected in the select statement as our schema + if (schema.nonEmpty) { + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + } val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) if (conf.convertCTAS && !hasStorageProperties) { @@ -1152,7 +1178,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { entry("mapkey.delim", ctx.keysTerminatedBy) ++ Option(ctx.linesSeparatedBy).toSeq.map { token => val value = string(token) - assert( + validate( value == "\n", s"LINES TERMINATED BY only supports newline '\\n' right now: $value", ctx) @@ -1209,7 +1235,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * * For example: * {{{ - * CREATE [TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name + * CREATE [OR REPLACE] [TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name * [(column_name [COMMENT column_comment], ...) ] * [COMMENT view_comment] * [TBLPROPERTIES (property_name = property_value, ...)] @@ -1224,60 +1250,44 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val schema = identifiers.map { ic => CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string)) } - createView( - ctx, - ctx.tableIdentifier, - comment = Option(ctx.STRING).map(string), - schema, - ctx.query, - Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty), - ctx.EXISTS != null, - ctx.REPLACE != null, - ctx.TEMPORARY != null - ) + + val sql = Option(source(ctx.query)) + val tableDesc = CatalogTable( + identifier = visitTableIdentifier(ctx.tableIdentifier), + tableType = CatalogTableType.VIEW, + schema = schema, + storage = CatalogStorageFormat.empty, + properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty), + viewOriginalText = sql, + viewText = sql, + comment = Option(ctx.STRING).map(string)) + + CreateViewCommand( + tableDesc, + plan(ctx.query), + allowExisting = ctx.EXISTS != null, + replace = ctx.REPLACE != null, + isTemporary = ctx.TEMPORARY != null) } } /** - * Alter the query of a view. This creates a [[CreateViewCommand]] command. + * Alter the query of a view. This creates a [[AlterViewAsCommand]] command. + * + * For example: + * {{{ + * ALTER VIEW [db_name.]view_name AS SELECT ...; + * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { - createView( - ctx, - ctx.tableIdentifier, - comment = None, - Seq.empty, - ctx.query, - Map.empty, - allowExist = false, - replace = true, - isTemporary = false) - } - - /** - * Create a [[CreateViewCommand]] command. - */ - private def createView( - ctx: ParserRuleContext, - name: TableIdentifierContext, - comment: Option[String], - schema: Seq[CatalogColumn], - query: QueryContext, - properties: Map[String, String], - allowExist: Boolean, - replace: Boolean, - isTemporary: Boolean): LogicalPlan = { - val sql = Option(source(query)) val tableDesc = CatalogTable( - identifier = visitTableIdentifier(name), + identifier = visitTableIdentifier(ctx.tableIdentifier), tableType = CatalogTableType.VIEW, - schema = schema, storage = CatalogStorageFormat.empty, - properties = properties, - viewOriginalText = sql, - viewText = sql, - comment = comment) - CreateViewCommand(tableDesc, plan(query), allowExist, replace, isTemporary) + schema = Nil, + viewOriginalText = Option(source(ctx.query))) + + AlterViewAsCommand(tableDesc, plan(ctx.query)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b619d4edc30d..e7faab549542 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow @@ -43,13 +42,12 @@ import org.apache.spark.sql.streaming.StreamingQuery * writing libraries should instead consider using the stable APIs provided in * [[org.apache.spark.sql.sources]] */ -@DeveloperApi abstract class SparkStrategy extends GenericStrategy[SparkPlan] { override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan) } -private[sql] case class PlanLater(plan: LogicalPlan) extends LeafExecNode { +case class PlanLater(plan: LogicalPlan) extends LeafExecNode { override def output: Seq[Attribute] = plan.output @@ -58,7 +56,7 @@ private[sql] case class PlanLater(plan: LogicalPlan) extends LeafExecNode { } } -private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { +abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => /** @@ -68,22 +66,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ReturnAnswer(rootPlan) => rootPlan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil + execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil case logical.Limit( IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) => execution.TakeOrderedAndProjectExec( - limit, order, Some(projectList), planLater(child)) :: Nil + limit, order, projectList, planLater(child)) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil + execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil case logical.Limit( IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) => execution.TakeOrderedAndProjectExec( - limit, order, Some(projectList), planLater(child)) :: Nil + limit, order, projectList, planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 484923428f4a..8ab553369de6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -40,12 +40,12 @@ import org.apache.spark.unsafe.Platform * * @param numFields the number of fields in the row being serialized. */ -private[sql] class UnsafeRowSerializer( +class UnsafeRowSerializer( numFields: Int, dataSize: SQLMetric = null) extends Serializer with Serializable { override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields, dataSize) - override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true + override def supportsRelocationOfSerializedObjects: Boolean = true } private class UnsafeRowSerializerInstance( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index ac4c3aae5f8e..fb57ed7692de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -295,7 +295,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, WholeStageCodegenExec.PIPELINE_DURATION_METRIC)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 4fbb9d554c9b..d004830d7da5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.execution.aggregate @@ -41,7 +59,7 @@ object AggUtils { aggregateExpressions = completeAggregateExpressions, aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, - resultExpressions = resultExpressions, + __resultExpressions = resultExpressions, child = child ) :: Nil } @@ -63,7 +81,7 @@ object AggUtils { aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, - resultExpressions = resultExpressions, + __resultExpressions = resultExpressions, child = child) } else { SortAggregateExec( @@ -72,7 +90,7 @@ object AggUtils { aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, - resultExpressions = resultExpressions, + __resultExpressions = resultExpressions, child = child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 34de76dd4ab4..6ca36e4acb44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -73,9 +73,10 @@ abstract class AggregationIterator( startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction](expressions.length) + val expressionsLength = expressions.length + val functions = new Array[AggregateFunction](expressionsLength) var i = 0 - while (i < expressions.length) { + while (i < expressionsLength) { val func = expressions(i).aggregateFunction val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => @@ -171,7 +172,7 @@ abstract class AggregationIterator( case PartialMerge | Final => (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } - } + }.toArray // This projection is used to merge buffer values for all expression-based aggregates. val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) val updateProjection = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 54d7340d8acd..6d90a37983b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.execution.aggregate @@ -40,21 +58,27 @@ case class HashAggregateExec( aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], + __resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode with CodegenSupport { - private[this] val aggregateBufferAttributes = { + @transient lazy val resultExpressions = __resultExpressions + + @transient lazy private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } + @transient lazy private[this] val aggregateBufferAttributesForGroup = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributesForGroup) + } + require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), @@ -277,7 +301,7 @@ case class HashAggregateExec( private val declFunctions = aggregateExpressions.map(_.aggregateFunction) .filter(_.isInstanceOf[DeclarativeAggregate]) .map(_.asInstanceOf[DeclarativeAggregate]) - private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributesForGroup) // The name for Vectorized HashMap private var vectorizedHashMapTerm: String = _ @@ -292,7 +316,7 @@ case class HashAggregateExec( */ def createHashMap(): UnsafeFixedWidthAggregationMap = { // create initialized aggregate buffer - val initExpr = declFunctions.flatMap(f => f.initialValues) + val initExpr = declFunctions.flatMap(_.initialValuesForGroup) val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) // create hashMap @@ -348,7 +372,7 @@ case class HashAggregateExec( val mergeExpr = declFunctions.flatMap(_.mergeExpressions) val mergeProjection = newMutableProjection( mergeExpr, - aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), + aggregateBufferAttributesForGroup ++ declFunctions.flatMap(_.inputAggBufferAttributes), subexpressionEliminationEnabled) val joinedRow = new JoinedRow() @@ -413,14 +437,14 @@ case class HashAggregateExec( } val evaluateKeyVars = evaluateVariables(keyVars) ctx.INPUT_ROW = bufferTerm - val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => + val bufferVars = aggregateBufferAttributesForGroup.zipWithIndex.map { case (e, i) => BoundReference(i, e.dataType, e.nullable).genCode(ctx) } val evaluateBufferVars = evaluateVariables(bufferVars) // evaluate the aggregation result ctx.currentVars = bufferVars val aggResults = declFunctions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) + BindReferences.bindReference(e, aggregateBufferAttributesForGroup).genCode(ctx) } val evaluateAggResults = evaluateVariables(aggResults) // generate the final result @@ -603,8 +627,6 @@ case class HashAggregateExec( // create grouping key ctx.currentVars = input - // make sure that the generated code will not be splitted as multiple functions - ctx.INPUT_ROW = null val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) val vectorizedRowKeys = ctx.generateExpressions( @@ -628,8 +650,8 @@ case class HashAggregateExec( ctx.currentVars = input val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) - val inputAttr = aggregateBufferAttributes ++ child.output - ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input + val inputAttr = aggregateBufferAttributesForGroup ++ child.output + ctx.currentVars = new Array[ExprCode](aggregateBufferAttributesForGroup.length) ++ input val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 05dbacf07a17..cc9d902b52f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.execution.aggregate @@ -36,11 +54,13 @@ case class SortAggregateExec( aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], + __resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode { - private[this] val aggregateBufferAttributes = { + @transient lazy val resultExpressions = __resultExpressions + + @transient lazy private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } @@ -49,7 +69,7 @@ case class SortAggregateExec( AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ AttributeSet(aggregateBufferAttributes) - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index b047bc0641dd..586e1456ac69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -204,7 +204,7 @@ sealed trait BufferSetterGetterUtils { /** * A Mutable [[Row]] representing a mutable aggregation buffer. */ -private[sql] class MutableAggregationBufferImpl ( +private[aggregate] class MutableAggregationBufferImpl( schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], @@ -266,7 +266,7 @@ private[sql] class MutableAggregationBufferImpl ( /** * A [[Row]] representing an immutable aggregation buffer. */ -private[sql] class InputAggregationBuffer private[sql] ( +private[aggregate] class InputAggregationBuffer( schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], @@ -319,7 +319,7 @@ private[sql] class InputAggregationBuffer private[sql] ( * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the * internal aggregation code path. */ -private[sql] case class ScalaUDAF( +case class ScalaUDAF( children: Seq[Expression], udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 185c79f899e6..a544371ffee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -102,7 +102,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) } } - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -228,7 +228,7 @@ case class SampleExec( child: SparkPlan) extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { @@ -260,6 +260,7 @@ case class SampleExec( if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") + ctx.copyResult = true ctx.addMutableState(s"$samplerClass", sampler, s"$initSampler();") @@ -312,12 +313,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) def start: Long = range.start def step: Long = range.step - def numSlices: Int = range.numSlices + def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) def numElements: BigInt = range.numElements override val output: Seq[Attribute] = range.output - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) // output attributes should not affect the results diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 7a14879b8b9d..96bd338f092e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -127,7 +127,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) var groupedAccessorsLength = 0 - groupedAccessorsItr.zipWithIndex.map { case (body, i) => + groupedAccessorsItr.zipWithIndex.foreach { case (body, i) => groupedAccessorsLength += 1 val funcName = s"accessors$i" val funcCode = s""" @@ -137,7 +137,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - groupedExtractorsItr.zipWithIndex.map { case (body, i) => + groupedExtractorsItr.zipWithIndex.foreach { case (body, i) => val funcName = s"extractors$i" val funcCode = s""" |private void $funcName() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 079e122a5a85..479934a7afc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.CollectionAccumulator -private[sql] object InMemoryRelation { +object InMemoryRelation { def apply( useCompression: Boolean, batchSize: Int, @@ -55,15 +55,15 @@ private[sql] object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -private[sql] case class InMemoryRelation( +case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, @transient child: SparkPlan, tableName: Option[String])( - @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, - private[sql] val batchStats: CollectionAccumulator[InternalRow] = + @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, + val batchStats: CollectionAccumulator[InternalRow] = child.sqlContext.sparkContext.collectionAccumulator[InternalRow]) extends logical.LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 183e4947b6d7..e63b313cb1d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType -private[sql] case class InMemoryTableScanExec( +case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) @@ -36,7 +36,7 @@ private[sql] case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = attributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 63eae1b8685a..0f4680e50278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -66,11 +66,7 @@ private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] } private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { - var i = 0 - while (i < compressionEncoders.length) { - compressionEncoders(i).gatherCompressibilityStats(row, ordinal) - i += 1 - } + compressionEncoders.foreach(_.gatherCompressibilityStats(row, ordinal)) } abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 941f03b745a0..089f3944e5c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashMap private[columnar] case object PassThrough extends CompressionScheme { @@ -208,7 +209,7 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { private var values = new mutable.ArrayBuffer[T#InternalType](1024) // The dictionary that maps a value to the encoded short integer. - private val dictionary = mutable.HashMap.empty[Any, Short] + private val dictionary = new OpenHashMap[Any, Short] // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int` // to store dictionary element count. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index a469d4da8613..07127533b034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -22,8 +22,9 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, SimpleCatalogRelation} /** @@ -38,10 +39,12 @@ case class AnalyzeTableCommand(tableName: String) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentwithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentwithDB)) relation match { - case relation: CatalogRelation => + case relation: CatalogRelation if !relation.isInstanceOf[SimpleCatalogRelation] => val catalogTable: CatalogTable = relation.catalogTable // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 7eaad81a8161..424a962b5eb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -35,9 +35,9 @@ import org.apache.spark.sql.types._ * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { +trait RunnableCommand extends LogicalPlan with logical.Command { override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty + final override def children: Seq[LogicalPlan] = Seq.empty def run(sparkSession: SparkSession): Seq[Row] } @@ -45,7 +45,7 @@ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. */ -private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index c38eca5156e5..de7d1fa0afe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.HiveSerDe @@ -74,14 +73,12 @@ case class CreateDataSourceTableCommand( s"characters, numbers and _.") } - val tableName = tableIdent.unquotedString val sessionState = sparkSession.sessionState - if (sessionState.catalog.tableExists(tableIdent)) { if (ignoreIfExists) { return Seq.empty[Row] } else { - throw new AnalysisException(s"Table $tableName already exists.") + throw new AnalysisException(s"Table ${tableIdent.unquotedString} already exists.") } } @@ -139,7 +136,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { // Since we are saving metadata to metastore, we need to check if metastore supports @@ -157,8 +154,11 @@ case class CreateDataSourceTableAsSelectCommand( s"characters, numbers and _.") } - val tableName = tableIdent.unquotedString val sessionState = sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = tableIdent.copy(database = Some(db)) + val tableName = tableIdentWithDB.unquotedString + var createMetastoreTable = false var isExternal = true val optionsWithPath = @@ -170,7 +170,9 @@ case class CreateDataSourceTableAsSelectCommand( } var existingSchema = Option.empty[StructType] - if (sparkSession.sessionState.catalog.tableExists(tableIdent)) { + // Pass a table identifier with database part, so that `tableExists` won't check temp views + // unexpectedly. + if (sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -194,14 +196,15 @@ case class CreateDataSourceTableAsSelectCommand( // TODO: Check that options from the resolved relation match the relation that we are // inserting into (i.e. using the same compression). - EliminateSubqueryAliases( - sessionState.catalog.lookupRelation(tableIdent)) match { + // Pass a table identifier with database part, so that `tableExists` won't check temp + // views unexpectedly. + EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => // check if the file formats match l.relation match { case r: HadoopFsRelation if r.fileFormat.getClass != dataSource.providingClass => throw new AnalysisException( - s"The file format of the existing table $tableIdent is " + + s"The file format of the existing table $tableName is " + s"`${r.fileFormat.getClass.getName}`. It doesn't match the specified " + s"format `$provider`") case _ => @@ -218,7 +221,7 @@ case class CreateDataSourceTableAsSelectCommand( throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } case SaveMode.Overwrite => - sparkSession.sql(s"DROP TABLE IF EXISTS $tableName") + sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true) // Need to create the table again. createMetastoreTable = true } @@ -246,7 +249,7 @@ case class CreateDataSourceTableAsSelectCommand( dataSource.write(mode, df) } catch { case ex: AnalysisException => - logError(s"Failed to write to table ${tableIdent.identifier} in $mode mode", ex) + logError(s"Failed to write to table $tableName in $mode mode", ex) throw ex } if (createMetastoreTable) { @@ -265,7 +268,7 @@ case class CreateDataSourceTableAsSelectCommand( } // Refresh the cache of the table in the catalog. - sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 226f61ef404a..16deee359f53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,18 +17,25 @@ package org.apache.spark.sql.execution.command +import scala.collection.{GenMap, GenSeq} +import scala.collection.parallel.ForkJoinTaskSupport +import scala.concurrent.forkjoin.ForkJoinPool import scala.util.control.NonFatal +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} + import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} -import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ - +import org.apache.spark.util.SerializableConfiguration // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -183,32 +190,25 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(tableName)) { - if (!ifExists) { - val objectName = if (isView) "View" else "Table" - throw new AnalysisException(s"$objectName to drop '$tableName' does not exist") - } - } else { - // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view - // issue an exception. - catalog.getTableMetadataOption(tableName).map(_.tableType match { - case CatalogTableType.VIEW if !isView => - throw new AnalysisException( - "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") - case o if o != CatalogTableType.VIEW && isView => - throw new AnalysisException( - s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") - case _ => - }) - try { - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession.table(tableName.quotedString)) - } catch { - case NonFatal(e) => log.warn(e.toString, e) - } - catalog.refreshTable(tableName) - catalog.dropTable(tableName, ifExists) + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") + case _ => + }) + try { + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName.quotedString)) + } catch { + case NonFatal(e) => log.warn(e.toString, e) } + catalog.refreshTable(tableName) + catalog.dropTable(tableName, ifExists) Seq.empty[Row] } } @@ -268,7 +268,7 @@ case class AlterTableUnsetPropertiesCommand( propKeys.foreach { k => if (!table.properties.contains(k)) { throw new AnalysisException( - s"Attempted to unset non-existent property '$k' in table '$tableName'") + s"Attempted to unset non-existent property '$k' in table '${table.identifier}'") } } } @@ -323,11 +323,11 @@ case class AlterTableSerDePropertiesCommand( catalog.alterTable(newTable) } else { val spec = partSpec.get - val part = catalog.getPartition(tableName, spec) + val part = catalog.getPartition(table.identifier, spec) val newPart = part.copy(storage = part.storage.copy( serde = serdeClassName.orElse(part.storage.serde), serdeProperties = part.storage.serdeProperties ++ serdeProperties.getOrElse(Map()))) - catalog.alterPartitions(tableName, Seq(newPart)) + catalog.alterPartitions(table.identifier, Seq(newPart)) } Seq.empty[Row] } @@ -363,7 +363,7 @@ case class AlterTableAddPartitionCommand( // inherit table storage format (possibly except for location) CatalogTablePartition(spec, table.storage.copy(locationUri = location)) } - catalog.createPartitions(tableName, parts, ignoreIfExists = ifNotExists) + catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) Seq.empty[Row] } @@ -418,13 +418,212 @@ case class AlterTableDropPartitionCommand( throw new AnalysisException( "ALTER TABLE DROP PARTITIONS is not allowed for tables defined using the datasource API") } - catalog.dropPartitions(tableName, specs, ignoreIfNotExists = ifExists) + catalog.dropPartitions(table.identifier, specs, ignoreIfNotExists = ifExists) Seq.empty[Row] } } +case class PartitionStatistics(numFiles: Int, totalSize: Long) + +/** + * Recover Partitions in ALTER TABLE: recover all the partition in the directory of a table and + * update the catalog. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table RECOVER PARTITIONS; + * MSCK REPAIR TABLE table; + * }}} + */ +case class AlterTableRecoverPartitionsCommand( + tableName: TableIdentifier, + cmd: String = "ALTER TABLE RECOVER PARTITIONS") extends RunnableCommand { + + // These are list of statistics that can be collected quickly without requiring a scan of the data + // see https://github.com/apache/hive/blob/master/ + // common/src/java/org/apache/hadoop/hive/common/StatsSetupConst.java + val NUM_FILES = "numFiles" + val TOTAL_SIZE = "totalSize" + val DDL_TIME = "transient_lastDdlTime" + + private def getPathFilter(hadoopConf: Configuration): PathFilter = { + // Dummy jobconf to get to the pathFilter defined in configuration + // It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow) + val jobConf = new JobConf(hadoopConf, this.getClass) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) { + pathFilter == null || pathFilter.accept(path) + } else { + false + } + } + } + } + + override def run(spark: SparkSession): Seq[Row] = { + val catalog = spark.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + val tableIdentWithDB = table.identifier.quotedString + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + s"Operation not allowed: $cmd on datasource tables: $tableIdentWithDB") + } + if (!DDLUtils.isTablePartitioned(table)) { + throw new AnalysisException( + s"Operation not allowed: $cmd only works on partitioned tables: $tableIdentWithDB") + } + if (table.storage.locationUri.isEmpty) { + throw new AnalysisException(s"Operation not allowed: $cmd only works on table with " + + s"location provided: $tableIdentWithDB") + } + + val root = new Path(table.storage.locationUri.get) + logInfo(s"Recover all the partitions in $root") + val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + + val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt + val hadoopConf = spark.sparkContext.hadoopConfiguration + val pathFilter = getPathFilter(hadoopConf) + val partitionSpecsAndLocs = scanPartitions( + spark, fs, pathFilter, root, Map(), table.partitionColumnNames.map(_.toLowerCase), threshold) + val total = partitionSpecsAndLocs.length + logInfo(s"Found $total partitions in $root") + + val partitionStats = if (spark.sqlContext.conf.gatherFastStats) { + gatherPartitionStats(spark, partitionSpecsAndLocs, fs, pathFilter, threshold) + } else { + GenMap.empty[String, PartitionStatistics] + } + logInfo(s"Finished to gather the fast stats for all $total partitions.") + + addPartitions(spark, table, partitionSpecsAndLocs, partitionStats) + logInfo(s"Recovered all partitions ($total).") + Seq.empty[Row] + } + + @transient private lazy val evalTaskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8)) + + private def scanPartitions( + spark: SparkSession, + fs: FileSystem, + filter: PathFilter, + path: Path, + spec: TablePartitionSpec, + partitionNames: Seq[String], + threshold: Int): GenSeq[(TablePartitionSpec, Path)] = { + if (partitionNames.isEmpty) { + return Seq(spec -> path) + } + + val statuses = fs.listStatus(path, filter) + val statusPar: GenSeq[FileStatus] = + if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { + // parallelize the list of partitions here, then we can have better parallelism later. + val parArray = statuses.par + parArray.tasksupport = evalTaskSupport + parArray + } else { + statuses + } + statusPar.flatMap { st => + val name = st.getPath.getName + if (st.isDirectory && name.contains("=")) { + val ps = name.split("=", 2) + val columnName = PartitioningUtils.unescapePathName(ps(0)).toLowerCase + // TODO: Validate the value + val value = PartitioningUtils.unescapePathName(ps(1)) + // comparing with case-insensitive, but preserve the case + if (columnName == partitionNames.head) { + scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(columnName -> value), + partitionNames.drop(1), threshold) + } else { + logWarning(s"expect partition column ${partitionNames.head}, but got ${ps(0)}, ignore it") + Seq() + } + } else { + logWarning(s"ignore ${new Path(path, name)}") + Seq() + } + } + } + + private def gatherPartitionStats( + spark: SparkSession, + partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)], + fs: FileSystem, + pathFilter: PathFilter, + threshold: Int): GenMap[String, PartitionStatistics] = { + if (partitionSpecsAndLocs.length > threshold) { + val hadoopConf = spark.sparkContext.hadoopConfiguration + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(serializedPaths.length, + Math.min(spark.sparkContext.defaultParallelism, 10000)) + // gather the fast stats for all the partitions otherwise Hive metastore will list all the + // files for all the new partitions in sequential way, which is super slow. + logInfo(s"Gather the fast stats in parallel using $numParallelism tasks.") + spark.sparkContext.parallelize(serializedPaths, numParallelism) + .mapPartitions { paths => + val pathFilter = getPathFilter(serializableConfiguration.value) + paths.map(new Path(_)).map{ path => + val fs = path.getFileSystem(serializableConfiguration.value) + val statuses = fs.listStatus(path, pathFilter) + (path.toString, PartitionStatistics(statuses.length, statuses.map(_.getLen).sum)) + } + }.collectAsMap() + } else { + partitionSpecsAndLocs.map { case (_, location) => + val statuses = fs.listStatus(location, pathFilter) + (location.toString, PartitionStatistics(statuses.length, statuses.map(_.getLen).sum)) + }.toMap + } + } + + private def addPartitions( + spark: SparkSession, + table: CatalogTable, + partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)], + partitionStats: GenMap[String, PartitionStatistics]): Unit = { + val total = partitionSpecsAndLocs.length + var done = 0L + // Hive metastore may not have enough memory to handle millions of partitions in single RPC, + // we should split them into smaller batches. Since Hive client is not thread safe, we cannot + // do this in parallel. + val batchSize = 100 + partitionSpecsAndLocs.toIterator.grouped(batchSize).foreach { batch => + val now = System.currentTimeMillis() / 1000 + val parts = batch.map { case (spec, location) => + val params = partitionStats.get(location.toString).map { + case PartitionStatistics(numFiles, totalSize) => + // This two fast stat could prevent Hive metastore to list the files again. + Map(NUM_FILES -> numFiles.toString, + TOTAL_SIZE -> totalSize.toString, + // Workaround a bug in HiveMetastore that try to mutate a read-only parameters. + // see metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java + DDL_TIME -> now.toString) + }.getOrElse(Map.empty) + // inherit table storage format (possibly except for location) + CatalogTablePartition( + spec, + table.storage.copy(locationUri = Some(location.toUri.toString)), + params) + } + spark.sessionState.catalog.createPartitions(tableName, parts, ignoreIfExists = true) + done += parts.length + logDebug(s"Recovered ${parts.length} partitions ($done/$total so far)") + } + } +} + + /** * A command that sets the location of a table or a partition. * @@ -448,7 +647,7 @@ case class AlterTableSetLocationCommand( partitionSpec match { case Some(spec) => // Partition spec is specified, so we set the location only for this partition - val part = catalog.getPartition(tableName, spec) + val part = catalog.getPartition(table.identifier, spec) val newPart = if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( @@ -457,7 +656,7 @@ case class AlterTableSetLocationCommand( } else { part.copy(storage = part.storage.copy(locationUri = Some(location))) } - catalog.alterPartitions(tableName, Seq(newPart)) + catalog.alterPartitions(table.identifier, Seq(newPart)) case None => // No partition spec is specified, so we set the location for the table itself val newTable = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index b2300b416d34..995feb3b670e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -33,8 +33,10 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, Catal import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -56,7 +58,12 @@ case class CreateHiveTableAsSelectLogicalPlan( } /** - * A command to create a table with the same definition of the given existing table. + * A command to create a MANAGED table with the same definition of the given existing table. + * In the target table definition, the table comment is always empty but the column comments + * are identical to the ones defined in the source table. + * + * The CatalogTable attributes copied from the source table are storage(inputFormat, outputFormat, + * serde, compressed, properties), schema, provider, partitionColumnNames, bucketSpec. * * The syntax of using this command in SQL is: * {{{ @@ -71,22 +78,54 @@ case class CreateTableLikeCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(sourceTable)) { - throw new AnalysisException( - s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") - } - if (catalog.isTemporaryTable(sourceTable)) { - throw new AnalysisException( - s"Source table in CREATE TABLE LIKE cannot be temporary: '$sourceTable'") + val sourceTableDesc = catalog.getTempViewOrPermanentTableMetadata(sourceTable) + + if (DDLUtils.isDatasourceTable(sourceTableDesc) || + sourceTableDesc.tableType == CatalogTableType.VIEW) { + val outputSchema = + StructType(sourceTableDesc.schema.map { c => + val builder = new MetadataBuilder + c.comment.map(comment => builder.putString("comment", comment)) + StructField( + c.name, + CatalystSqlParser.parseDataType(c.dataType), + c.nullable, + metadata = builder.build()) + }) + val (schema, provider) = if (DDLUtils.isDatasourceTable(sourceTableDesc)) { + (DDLUtils.getSchemaFromTableProperties(sourceTableDesc).getOrElse(outputSchema), + sourceTableDesc.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER)) + } else { // VIEW + (outputSchema, sparkSession.sessionState.conf.defaultDataSourceName) + } + createDataSourceTable( + sparkSession = sparkSession, + tableIdent = targetTable, + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = provider, + options = Map("path" -> catalog.defaultTablePath(targetTable)), + isExternal = false) + } else { + val newStorage = + sourceTableDesc.storage.copy( + locationUri = None, + serdeProperties = sourceTableDesc.storage.serdeProperties) + val newTableDesc = + CatalogTable( + identifier = targetTable, + tableType = CatalogTableType.MANAGED, + storage = newStorage, + schema = sourceTableDesc.schema, + partitionColumnNames = sourceTableDesc.partitionColumnNames, + sortColumnNames = sourceTableDesc.sortColumnNames, + bucketColumnNames = sourceTableDesc.bucketColumnNames, + numBuckets = sourceTableDesc.numBuckets) + + catalog.createTable(newTableDesc, ifNotExists) } - val tableToCreate = catalog.getTableMetadata(sourceTable).copy( - identifier = targetTable, - tableType = CatalogTableType.MANAGED, - createTime = System.currentTimeMillis, - lastAccessTime = -1).withNewStorage(locationUri = None) - - catalog.createTable(tableToCreate, ifNotExists) Seq.empty[Row] } } @@ -145,13 +184,13 @@ case class AlterTableRenameCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - DDLUtils.verifyAlterTableType(catalog, oldName, isView) // If this is a temp view, just rename the view. // Otherwise, if this is a real table, we also need to uncache and invalidate the table. - val isTemporary = catalog.isTemporaryTable(oldName) - if (isTemporary) { + if (catalog.isTemporaryTable(oldName)) { catalog.renameTable(oldName, newName) } else { + val table = catalog.getTableMetadata(oldName) + DDLUtils.verifyAlterTableType(catalog, table.identifier, isView) // If an exception is thrown here we can just assume the table is uncached; // this can happen with Hive tables when the underlying catalog is in-memory. val wasCached = Try(sparkSession.catalog.isCached(oldName.unquotedString)).getOrElse(false) @@ -163,7 +202,6 @@ case class AlterTableRenameCommand( } } // For datasource tables, we also need to update the "path" serde property - val table = catalog.getTableMetadata(oldName) if (DDLUtils.isDatasourceTable(table) && table.tableType == CatalogTableType.MANAGED) { val newPath = catalog.defaultTablePath(newName) val newTable = table.withNewStorage( @@ -201,37 +239,34 @@ case class LoadDataCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(table)) { - throw new AnalysisException(s"Target table in LOAD DATA does not exist: '$table'") - } - val targetTable = catalog.getTableMetadataOption(table).getOrElse { - throw new AnalysisException(s"Target table in LOAD DATA cannot be temporary: '$table'") - } + val targetTable = catalog.getTableMetadata(table) + val tableIdentwithDB = targetTable.identifier.quotedString if (DDLUtils.isDatasourceTable(targetTable)) { - throw new AnalysisException(s"LOAD DATA is not supported for datasource tables: '$table'") + throw new AnalysisException( + s"LOAD DATA is not supported for datasource tables: '$tableIdentwithDB'") } if (targetTable.partitionColumnNames.nonEmpty) { if (partition.isEmpty) { - throw new AnalysisException(s"LOAD DATA target table '$table' is partitioned, " + + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is partitioned, " + s"but no partition spec is provided") } if (targetTable.partitionColumnNames.size != partition.get.size) { - throw new AnalysisException(s"LOAD DATA target table '$table' is partitioned, " + + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is partitioned, " + s"but number of columns in provided partition spec (${partition.get.size}) " + s"do not match number of partitioned columns in table " + s"(s${targetTable.partitionColumnNames.size})") } partition.get.keys.foreach { colName => if (!targetTable.partitionColumnNames.contains(colName)) { - throw new AnalysisException(s"LOAD DATA target table '$table' is partitioned, " + - s"but the specified partition spec refers to a column that is not partitioned: " + - s"'$colName'") + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is " + + s"partitioned, but the specified partition spec refers to a column that is " + + s"not partitioned: '$colName'") } } } else { if (partition.nonEmpty) { - throw new AnalysisException(s"LOAD DATA target table '$table' is not partitioned, " + - s"but a partition spec was provided.") + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is not " + + s"partitioned, but a partition spec was provided.") } } @@ -320,32 +355,26 @@ case class TruncateTableCommand( override def run(spark: SparkSession): Seq[Row] = { val catalog = spark.sessionState.catalog - if (!catalog.tableExists(tableName)) { - throw new AnalysisException(s"Table '$tableName' in TRUNCATE TABLE does not exist.") - } - if (catalog.isTemporaryTable(tableName)) { - throw new AnalysisException( - s"Operation not allowed: TRUNCATE TABLE on temporary tables: '$tableName'") - } val table = catalog.getTableMetadata(tableName) + val tableIdentwithDB = table.identifier.quotedString if (table.tableType == CatalogTableType.EXTERNAL) { throw new AnalysisException( - s"Operation not allowed: TRUNCATE TABLE on external tables: '$tableName'") + s"Operation not allowed: TRUNCATE TABLE on external tables: '$tableIdentwithDB'") } if (table.tableType == CatalogTableType.VIEW) { throw new AnalysisException( - s"Operation not allowed: TRUNCATE TABLE on views: '$tableName'") + s"Operation not allowed: TRUNCATE TABLE on views: '$tableIdentwithDB'") } val isDatasourceTable = DDLUtils.isDatasourceTable(table) if (isDatasourceTable && partitionSpec.isDefined) { throw new AnalysisException( s"Operation not allowed: TRUNCATE TABLE ... PARTITION is not supported " + - s"for tables created using the data sources API: '$tableName'") + s"for tables created using the data sources API: '$tableIdentwithDB'") } if (table.partitionColumnNames.isEmpty && partitionSpec.isDefined) { throw new AnalysisException( s"Operation not allowed: TRUNCATE TABLE ... PARTITION is not supported " + - s"for tables that are not partitioned: '$tableName'") + s"for tables that are not partitioned: '$tableIdentwithDB'") } val locations = if (isDatasourceTable) { @@ -353,7 +382,7 @@ case class TruncateTableCommand( } else if (table.partitionColumnNames.isEmpty) { Seq(table.storage.locationUri) } else { - catalog.listPartitions(tableName, partitionSpec).map(_.storage.locationUri) + catalog.listPartitions(table.identifier, partitionSpec).map(_.storage.locationUri) } val hadoopConf = spark.sessionState.newHadoopConf() locations.foreach { location => @@ -366,8 +395,8 @@ case class TruncateTableCommand( } catch { case NonFatal(e) => throw new AnalysisException( - s"Failed to truncate table '$tableName' when removing data of the path: $path " + - s"because of ${e.toString}") + s"Failed to truncate table '$tableIdentwithDB' when removing data of the path: " + + s"$path because of ${e.toString}") } } } @@ -376,10 +405,10 @@ case class TruncateTableCommand( spark.sessionState.refreshTable(tableName.unquotedString) // Also try to drop the contents of the table from the columnar cache try { - spark.sharedState.cacheManager.uncacheQuery(spark.table(tableName.quotedString)) + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) } catch { case NonFatal(e) => - log.warn(s"Exception when attempting to uncache table '$tableName'", e) + log.warn(s"Exception when attempting to uncache table '$tableIdentwithDB'", e) } Seq.empty[Row] } @@ -436,11 +465,12 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { if (DDLUtils.isDatasourceTable(table)) { - val partCols = DDLUtils.getPartitionColumnsFromTableProperties(table) - if (partCols.nonEmpty) { + val userSpecifiedSchema = DDLUtils.getSchemaFromTableProperties(table) + val partColNames = DDLUtils.getPartitionColumnsFromTableProperties(table) + for (schema <- userSpecifiedSchema if partColNames.nonEmpty) { append(buffer, "# Partition Information", "", "") - append(buffer, s"# ${output.head.name}", "", "") - partCols.foreach(col => append(buffer, col, "", "")) + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + describeSchema(StructType(partColNames.map(schema(_))), buffer) } } else { if (table.partitionColumns.nonEmpty) { @@ -527,7 +557,7 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { schema.foreach { column => val comment = - if (column.metadata.contains("comment")) column.metadata.getString("comment") else "" + if (column.metadata.contains("comment")) column.metadata.getString("comment") else null append(buffer, column.name, column.dataType.simpleString, comment) } } @@ -622,14 +652,16 @@ case class ShowTablePropertiesCommand(table: TableIdentifier, propertyKey: Optio * SHOW COLUMNS (FROM | IN) table_identifier [(FROM | IN) database]; * }}} */ -case class ShowColumnsCommand(table: TableIdentifier) extends RunnableCommand { +case class ShowColumnsCommand(tableName: TableIdentifier) extends RunnableCommand { // The result of SHOW COLUMNS has one column called 'result' override val output: Seq[Attribute] = { AttributeReference("result", StringType, nullable = false)() :: Nil } override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.sessionState.catalog.getTableMetadata(table).schema.map { c => + val catalog = sparkSession.sessionState.catalog + val table = catalog.getTempViewOrPermanentTableMetadata(tableName) + table.schema.map { c => Row(c.name) } } @@ -651,7 +683,7 @@ case class ShowColumnsCommand(table: TableIdentifier) extends RunnableCommand { * }}} */ case class ShowPartitionsCommand( - table: TableIdentifier, + tableName: TableIdentifier, spec: Option[TablePartitionSpec]) extends RunnableCommand { // The result of SHOW PARTITIONS has one column called 'result' override val output: Seq[Attribute] = { @@ -666,34 +698,27 @@ case class ShowPartitionsCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - - if (catalog.isTemporaryTable(table)) { - throw new AnalysisException( - s"SHOW PARTITIONS is not allowed on a temporary table: ${table.unquotedString}") - } - - val tab = catalog.getTableMetadata(table) + val table = catalog.getTableMetadata(tableName) + val tableIdentWithDB = table.identifier.quotedString /** * Validate and throws an [[AnalysisException]] exception under the following conditions: * 1. If the table is not partitioned. * 2. If it is a datasource table. - * 3. If it is a view or index table. + * 3. If it is a view. */ - if (tab.tableType == VIEW || - tab.tableType == INDEX) { - throw new AnalysisException( - s"SHOW PARTITIONS is not allowed on a view or index table: ${tab.qualifiedName}") + if (table.tableType == VIEW) { + throw new AnalysisException(s"SHOW PARTITIONS is not allowed on a view: $tableIdentWithDB") } - if (!DDLUtils.isTablePartitioned(tab)) { + if (!DDLUtils.isTablePartitioned(table)) { throw new AnalysisException( - s"SHOW PARTITIONS is not allowed on a table that is not partitioned: ${tab.qualifiedName}") + s"SHOW PARTITIONS is not allowed on a table that is not partitioned: $tableIdentWithDB") } - if (DDLUtils.isDatasourceTable(tab)) { + if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( - s"SHOW PARTITIONS is not allowed on a datasource table: ${tab.qualifiedName}") + s"SHOW PARTITIONS is not allowed on a datasource table: $tableIdentWithDB") } /** @@ -702,7 +727,7 @@ case class ShowPartitionsCommand( * thrown if the partitioning spec is invalid. */ if (spec.isDefined) { - val badColumns = spec.get.keySet.filterNot(tab.partitionColumns.map(_.name).contains) + val badColumns = spec.get.keySet.filterNot(table.partitionColumns.map(_.name).contains) if (badColumns.nonEmpty) { val badCols = badColumns.mkString("[", ", ", "]") throw new AnalysisException( @@ -710,8 +735,8 @@ case class ShowPartitionsCommand( } } - val partNames = catalog.listPartitions(table, spec).map { p => - getPartName(p.spec, tab.partitionColumnNames) + val partNames = catalog.listPartitions(tableName, spec).map { p => + getPartName(p.spec, table.partitionColumnNames) } partNames.map(Row(_)) @@ -725,16 +750,6 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - - if (catalog.isTemporaryTable(table)) { - throw new AnalysisException( - s"SHOW CREATE TABLE cannot be applied to temporary table") - } - - if (!catalog.tableExists(table)) { - throw new AnalysisException(s"Table $table doesn't exist") - } - val tableMetadata = catalog.getTableMetadata(table) val stmt = if (DDLUtils.isDatasourceTable(tableMetadata)) { @@ -765,7 +780,6 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman case EXTERNAL => " EXTERNAL TABLE" case VIEW => " VIEW" case MANAGED => " TABLE" - case INDEX => reportUnsupportedError(Seq("index table")) } builder ++= s"CREATE$tableTypeString ${table.quotedString}" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 6533d796e806..125b3d1b0587 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -22,7 +22,7 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.{SQLBuilder, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} +import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} @@ -56,8 +56,6 @@ case class CreateViewCommand( // TODO: Note that this class can NOT canonicalize the view SQL string entirely, which is // different from Hive and may not work for some cases like create view on self join. - override def output: Seq[Attribute] = Seq.empty[Attribute] - require(tableDesc.tableType == CatalogTableType.VIEW, "The type of the table to created with CREATE VIEW must be 'CatalogTableType.VIEW'.") if (!isTemporary) { @@ -191,8 +189,7 @@ case class CreateViewCommand( sparkSession.sql(viewSQL).queryExecution.assertAnalyzed() } catch { case NonFatal(e) => - throw new RuntimeException( - "Failed to analyze the canonicalized SQL. It is possible there is a bug in Spark.", e) + throw new RuntimeException(s"Failed to analyze the canonicalized SQL: $viewSQL", e) } val viewSchema: Seq[CatalogColumn] = { @@ -213,3 +210,68 @@ case class CreateViewCommand( /** Escape backtick with double-backtick in column name and wrap it with backtick. */ private def quote(name: String) = s"`${name.replaceAll("`", "``")}`" } + +/** + * Alter a view with given query plan. If the view name contains database prefix, this command will + * alter a permanent view matching the given name, or throw an exception if view not exist. Else, + * this command will try to alter a temporary view first, if view not exist, try permanent view + * next, if still not exist, throw an exception. + * + * @param tableDesc the catalog table + * @param query the logical plan that represents the view; this is used to generate a canonicalized + * version of the SQL that can be saved in the catalog. + */ +case class AlterViewAsCommand( + tableDesc: CatalogTable, + query: LogicalPlan) extends RunnableCommand { + + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + + override def run(session: SparkSession): Seq[Row] = { + // If the plan cannot be analyzed, throw an exception and don't proceed. + val qe = session.sessionState.executePlan(query) + qe.assertAnalyzed() + val analyzedPlan = qe.analyzed + + if (session.sessionState.catalog.isTemporaryTable(tableDesc.identifier)) { + session.sessionState.catalog.createTempView( + tableDesc.identifier.table, + analyzedPlan, + overrideIfExists = true) + } else { + alterPermanentView(session, analyzedPlan) + } + + Seq.empty[Row] + } + + private def alterPermanentView(session: SparkSession, analyzedPlan: LogicalPlan): Unit = { + val viewMeta = session.sessionState.catalog.getTableMetadata(tableDesc.identifier) + if (viewMeta.tableType != CatalogTableType.VIEW) { + throw new AnalysisException(s"${viewMeta.identifier} is not a view.") + } + + val viewSQL: String = new SQLBuilder(analyzedPlan).toSQL + // Validate the view SQL - make sure we can parse it and analyze it. + // If we cannot analyze the generated query, there is probably a bug in SQL generation. + try { + session.sql(viewSQL).queryExecution.assertAnalyzed() + } catch { + case NonFatal(e) => + throw new RuntimeException(s"Failed to analyze the canonicalized SQL: $viewSQL", e) + } + + val viewSchema: Seq[CatalogColumn] = { + analyzedPlan.output.map { a => + CatalogColumn(a.name, a.dataType.catalogString) + } + } + + val updatedViewMeta = viewMeta.copy( + schema = viewSchema, + viewOriginalText = tableDesc.viewOriginalText, + viewText = Some(viewSQL)) + + session.sessionState.catalog.alterTable(updatedViewMeta) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f572b93991e0..ee37390c91dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import java.util.ServiceLoader +import java.util.{ServiceConfigurationError, ServiceLoader} import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} @@ -30,6 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -123,50 +124,64 @@ case class DataSource( val loader = Utils.getContextOrSparkClassLoader val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { - // the provider format did not match any given registered aliases - case Nil => - try { - Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => - // Found the data source using fully qualified path - dataSource - case Failure(error) => - if (provider.toLowerCase == "orc" || + try { + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { + // the provider format did not match any given registered aliases + case Nil => + try { + Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider.toLowerCase == "orc" || provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new AnalysisException( - "The ORC data source must be used with Hive support enabled") - } else if (provider.toLowerCase == "avro" || + throw new AnalysisException( + "The ORC data source must be used with Hive support enabled") + } else if (provider.toLowerCase == "avro" || provider == "com.databricks.spark.avro") { - throw new AnalysisException( - s"Failed to find data source: ${provider.toLowerCase}. Please use Spark " + - "package http://spark-packages.org/package/databricks/spark-avro") + throw new AnalysisException( + s"Failed to find data source: ${provider.toLowerCase}. Please find an Avro " + + "package at " + + "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects") + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please find packages at " + + "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects", + error) + } + } + } catch { + case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " + + "Please check if your library is compatible with Spark 2.0", e) } else { - throw new ClassNotFoundException( - s"Failed to find data source: $provider. Please find packages at " + - "http://spark-packages.org", - error) + throw e } } - } catch { - case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal - // NoClassDefFoundError's class name uses "/" rather than "." for packages - val className = e.getMessage.replaceAll("/", ".") - if (spark2RemovedClasses.contains(className)) { - throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " + - "Please check if your library is compatible with Spark 2.0", e) - } else { - throw e - } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input + sys.error(s"Multiple sources found for $provider " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } catch { + case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] => + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getCause.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " + + "Please remove the incompatible library from classpath or upgrade it. " + + s"Error: ${e.getMessage}", e) + } else { + throw e } - case head :: Nil => - // there is exactly one registered alias - head.getClass - case sources => - // There are multiple registered aliases for the input - sys.error(s"Multiple sources found for $provider " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name.") } } @@ -336,13 +351,12 @@ case class DataSource( } HadoopFsRelation( - sparkSession, fileCatalog, partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema, bucketSpec = None, format, - options) + options)(sparkSession) // This is a non-streaming file based datasource. case (format: FileFormat, _) => @@ -400,13 +414,12 @@ case class DataSource( } HadoopFsRelation( - sparkSession, fileCatalog, partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, - caseInsensitiveOptions) + caseInsensitiveOptions)(sparkSession) case _ => throw new AnalysisException( @@ -471,13 +484,23 @@ case class DataSource( } } + // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does + // not need to have the query as child, to avoid to analyze an optimized query, + // because InsertIntoHadoopFsRelationCommand will be optimized first. + val columns = partitionColumns.map { name => + val plan = data.logicalPlan + plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { + throw new AnalysisException( + s"Unable to resolve ${name} given [${plan.output.map(_.name).mkString(", ")}]") + }.asInstanceOf[Attribute] + } // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( outputPath, - partitionColumns.map(UnresolvedAttribute.quoted), + columns, bucketSpec, format, () => Unit, // No existing table needs to be refreshed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 27133f0a43f2..277969465e42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -43,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String * Replaces generic operations with specific variants that are designed to work with Spark * SQL Data Sources. */ -private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { def resolver: Resolver = { if (conf.caseSensitiveAnalysis) { @@ -54,7 +54,7 @@ private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[Logi } // The access modifier is used to expose this method to tests. - private[sql] def convertStaticPartitions( + def convertStaticPartitions( sourceAttributes: Seq[Attribute], providedPartitions: Map[String, Option[String]], targetAttributes: Seq[Attribute], @@ -187,7 +187,7 @@ private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[Logi InsertIntoHadoopFsRelationCommand( outputPath, - t.partitionSchema.fields.map(_.name).map(UnresolvedAttribute(_)), + query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), t.bucketSpec, t.fileFormat, () => t.refresh(), @@ -202,7 +202,7 @@ private[sql] case class DataSourceAnalysis(conf: CatalystConf) extends Rule[Logi * Replaces [[SimpleCatalogRelation]] with data source table if its table property contains data * source information. */ -private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { +class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { private def readDataSourceTable(sparkSession: SparkSession, table: CatalogTable): LogicalPlan = { val userSpecifiedSchema = DDLUtils.getSchemaFromTableProperties(table) @@ -242,7 +242,7 @@ private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[ /** * A Strategy for planning scans over data sources defined using the sources API. */ -private[sql] object DataSourceStrategy extends Strategy with Logging { +object DataSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -347,13 +347,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + // These metadata values make scan plans uniquely identifiable for equality checking. + // TODO(SPARK-17701) using strings for equality checking is brittle val metadata: Map[String, String] = { val pairs = ArrayBuffer.empty[(String, String)] if (pushedFilters.nonEmpty) { pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]")) } - + pairs += ("ReadSchema" -> + StructType.fromAttributes(projects.map(_.toAttribute)).catalogString) pairs.toMap } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 13a86bfb3896..74510f9c08b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -54,7 +54,7 @@ import org.apache.spark.sql.execution.SparkPlan * is under the threshold with the addition of the next file, add it. If not, open a new bucket * and add it. Proceed to the next file. */ -private[sql] object FileSourceStrategy extends Strategy with Logging { +object FileSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, table)) => @@ -202,7 +202,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { partitions } + // These metadata values make scan plans uniquely identifiable for equality checking. val meta = Map( + "PartitionFilters" -> partitionKeyFilters.mkString("[", ", ", "]"), "Format" -> files.fileFormat.toString, "ReadSchema" -> prunedDataSchema.simpleString, PUSHED_FILTERS -> pushedDownFilters.mkString("[", ", ", "]"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 18f9b55895a6..83cf26c63a17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable import java.net.URI import org.apache.hadoop.conf.Configuration @@ -30,7 +31,8 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. */ -class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] { +class HadoopFileLinesReader( + file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -48,4 +50,6 @@ class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends override def hasNext: Boolean = iterator.hasNext override def next(): Text = iterator.next() + + override def close(): Unit = iterator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 8549ae96e2f3..b2ff68a833fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.sources.InsertableRelation /** * Inserts the results of `query` in to a relation that extends [[InsertableRelation]]. */ -private[sql] case class InsertIntoDataSourceCommand( +case class InsertIntoDataSourceCommand( logicalRelation: LogicalRelation, query: LogicalPlan, overwrite: Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 1426dcf4697f..518b02b71875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.internal.SQLConf * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is * thrown during job commitment, also aborts the job. */ -private[sql] case class InsertIntoHadoopFsRelationCommand( +case class InsertIntoHadoopFsRelationCommand( outputPath: Path, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -66,7 +66,7 @@ private[sql] case class InsertIntoHadoopFsRelationCommand( mode: SaveMode) extends RunnableCommand { - override def children: Seq[LogicalPlan] = query :: Nil + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 90711f2b1dde..2a8e147011f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -79,11 +79,18 @@ case class LogicalRelation( /** Used to lookup original attribute capitalization */ val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o))) - def newInstance(): this.type = + /** + * Returns a new instance of this LogicalRelation. According to the semantics of + * MultiInstanceRelation, this method returns a copy of this object with + * unique expression ids. We respect the `expectedOutputAttributes` and create + * new instances of attributes in it. + */ + override def newInstance(): this.type = { LogicalRelation( relation, - expectedOutputAttributes, + expectedOutputAttributes.map(_.map(_.newInstance())), metastoreTableIdentifier).asInstanceOf[this.type] + } override def refresh(): Unit = relation match { case fs: HadoopFsRelation => fs.refresh() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala index 811e96c99a96..2130c27ebd8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala @@ -76,7 +76,15 @@ abstract class PartitioningAwareFileCatalog( paths.flatMap { path => // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). val fs = path.getFileSystem(hadoopConf) - val qualifiedPath = fs.makeQualified(path) + val qualifiedPathPre = fs.makeQualified(path) + val qualifiedPath: Path = if (qualifiedPathPre.isRoot && !qualifiedPathPre.isAbsolute) { + // SPARK-17613: Always append `Path.SEPARATOR` to the end of parent directories, + // because the `leafFile.getParent` would have returned an absolute path with the + // separator at the end. + new Path(qualifiedPathPre, Path.SEPARATOR) + } else { + qualifiedPathPre + } // There are three cases possible with each path // 1. The path is a directory and has children files in it. Then it must be present in @@ -204,6 +212,6 @@ abstract class PartitioningAwareFileCatalog( private def isDataPath(path: Path): Boolean = { val name = path.getName - !(name.startsWith("_") || name.startsWith(".")) + !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index c3561099d684..504464216e5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ +// TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. object PartitionDirectory { def apply(values: InternalRow, path: String): PartitionDirectory = @@ -41,22 +42,23 @@ object PartitionDirectory { * Holds a directory in a partitioned collection of files as well as as the partition values * in the form of a Row. Before scanning, the files at `path` need to be enumerated. */ -private[sql] case class PartitionDirectory(values: InternalRow, path: Path) +case class PartitionDirectory(values: InternalRow, path: Path) -private[sql] case class PartitionSpec( +case class PartitionSpec( partitionColumns: StructType, partitions: Seq[PartitionDirectory]) -private[sql] object PartitionSpec { +object PartitionSpec { val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory]) } -private[sql] object PartitioningUtils { +object PartitioningUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't // depend on Hive. - private[sql] val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" + val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" - private[sql] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { + private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) + { require(columnNames.size == literals.size) } @@ -83,7 +85,7 @@ private[sql] object PartitioningUtils { * path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28"))) * }}} */ - private[sql] def parsePartitions( + private[datasources] def parsePartitions( paths: Seq[Path], defaultPartitionName: String, typeInference: Boolean, @@ -166,7 +168,7 @@ private[sql] object PartitioningUtils { * hdfs://:/path/to/partition * }}} */ - private[sql] def parsePartition( + private[datasources] def parsePartition( path: Path, defaultPartitionName: String, typeInference: Boolean, @@ -249,7 +251,7 @@ private[sql] object PartitioningUtils { * DoubleType -> StringType * }}} */ - private[sql] def resolvePartitions( + def resolvePartitions( pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = { if (pathsWithPartitionValues.isEmpty) { Seq.empty @@ -275,7 +277,7 @@ private[sql] object PartitioningUtils { } } - private[sql] def listConflictingPartitionColumns( + private[datasources] def listConflictingPartitionColumns( pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = { val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct @@ -308,7 +310,7 @@ private[sql] object PartitioningUtils { * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and * [[StringType]]. */ - private[sql] def inferPartitionColumnValue( + private[datasources] def inferPartitionColumnValue( raw: String, defaultPartitionName: String, typeInference: Boolean): Literal = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala index f03ae94d5583..938af25a9684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable + import org.apache.hadoop.mapreduce.RecordReader import org.apache.spark.sql.catalyst.InternalRow @@ -27,7 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow * Note that this returns [[Object]]s instead of [[InternalRow]] because we rely on erasure to pass * column batches by pretending they are rows. */ -class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] { +class RecordReaderIterator[T]( + private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable { private[this] var havePair = false private[this] var finished = false @@ -38,7 +41,7 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] // Close and release the reader here; close() will also be called when the task // completes, but for tasks that read from many files, it helps to release the // resources early. - rowReader.close() + close() } havePair = !finished } @@ -52,4 +55,18 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] havePair = false rowReader.getCurrentValue } + + override def close(): Unit = { + if (rowReader != null) { + try { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues + // when reading compressed input. + rowReader.close() + } finally { + rowReader = null + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 9a0b46c1a4a5..e25924b1ba1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -40,14 +40,14 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A container for all the details required when writing to a table. */ -case class WriteRelation( +private[datasources] case class WriteRelation( sparkSession: SparkSession, dataSchema: StructType, path: String, prepareJobForWrite: Job => OutputWriterFactory, bucketSpec: Option[BucketSpec]) -private[sql] abstract class BaseWriterContainer( +private[datasources] abstract class BaseWriterContainer( @transient val relation: WriteRelation, @transient private val job: Job, isAppend: Boolean) @@ -234,7 +234,7 @@ private[sql] abstract class BaseWriterContainer( /** * A writer that writes all of the rows in a partition to a single file. */ -private[sql] class DefaultWriterContainer( +private[datasources] class DefaultWriterContainer( relation: WriteRelation, job: Job, isAppend: Boolean) @@ -293,7 +293,7 @@ private[sql] class DefaultWriterContainer( * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the * writer externally sorts the remaining rows and then writes out them out one file at a time. */ -private[sql] class DynamicPartitionWriterContainer( +private[datasources] class DynamicPartitionWriterContainer( relation: WriteRelation, job: Job, partitionColumns: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 6008d73717f7..2bafe967993b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -31,7 +31,7 @@ private[sql] case class BucketSpec( bucketColumnNames: Seq[String], sortColumnNames: Seq[String]) -private[sql] object BucketingUtils { +object BucketingUtils { // The file name of bucketed data should have 3 parts: // 1. some other information in the head of file name // 2. bucket id part, some numbers, starts with "_" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 12e19f955caa..107b6007ce46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -111,7 +112,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val lineIterator = { val conf = broadcastedHadoopConf.value.value - new HadoopFileLinesReader(file, conf).map { line => + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => new String(line.getBytes, 0, line.getLength, csvOptions.charset) } } @@ -180,13 +183,18 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } private def verifySchema(schema: StructType): Unit = { - schema.foreach { field => - field.dataType match { - case _: ArrayType | _: MapType | _: StructType => - throw new UnsupportedOperationException( - s"CSV data source does not support ${field.dataType.simpleString} data type.") - case _ => - } + def verifyType(dataType: DataType): Unit = dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | BooleanType | _: DecimalType | TimestampType | + DateType | StringType => + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"CSV data source does not support ${dataType.simpleString} data type.") } + + schema.foreach(field => verifyType(field.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index de3d889621b7..3ab775c90923 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -139,20 +139,14 @@ private[csv] object CSVInferSchema { } private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { - if (options.dateFormat != null) { - // This case infers a custom `dataFormat` is set. - if ((allCatch opt options.dateFormat.parse(field)).isDefined) { - TimestampType - } else { - tryParseBoolean(field, options) - } - } else { + // This case infers a custom `dataFormat` is set. + if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { + TimestampType + } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { // We keep this for backwords competibility. - if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - TimestampType - } else { - tryParseBoolean(field, options) - } + TimestampType + } else { + tryParseBoolean(field, options) } } @@ -238,59 +232,58 @@ private[csv] object CSVTypeCast { nullable: Boolean = true, options: CSVOptions = CSVOptions()): Any = { - castType match { - case _: ByteType => if (datum == options.nullValue && nullable) null else datum.toByte - case _: ShortType => if (datum == options.nullValue && nullable) null else datum.toShort - case _: IntegerType => if (datum == options.nullValue && nullable) null else datum.toInt - case _: LongType => if (datum == options.nullValue && nullable) null else datum.toLong - case _: FloatType => - if (datum == options.nullValue && nullable) { - null - } else if (datum == options.nanValue) { - Float.NaN - } else if (datum == options.negativeInf) { - Float.NegativeInfinity - } else if (datum == options.positiveInf) { - Float.PositiveInfinity - } else { - Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) - } - case _: DoubleType => - if (datum == options.nullValue && nullable) { - null - } else if (datum == options.nanValue) { - Double.NaN - } else if (datum == options.negativeInf) { - Double.NegativeInfinity - } else if (datum == options.positiveInf) { - Double.PositiveInfinity - } else { - Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) - } - case _: BooleanType => datum.toBoolean - case dt: DecimalType => - if (datum == options.nullValue && nullable) { - null - } else { + if (nullable && datum == options.nullValue) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => + datum match { + case options.nanValue => Float.NaN + case options.negativeInf => Float.NegativeInfinity + case options.positiveInf => Float.PositiveInfinity + case _ => + Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + } + case _: DoubleType => + datum match { + case options.nanValue => Double.NaN + case options.negativeInf => Double.NegativeInfinity + case options.positiveInf => Double.PositiveInfinity + case _ => + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + } + case _: BooleanType => datum.toBoolean + case dt: DecimalType => val value = new BigDecimal(datum.replaceAll(",", "")) Decimal(value, dt.precision, dt.scale) - } - case _: TimestampType if options.dateFormat != null => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - options.dateFormat.parse(datum).getTime * 1000L - case _: TimestampType => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - DateTimeUtils.stringToTime(datum).getTime * 1000L - case _: DateType if options.dateFormat != null => - DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime) - case _: DateType => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) - case _: StringType => UTF8String.fromString(datum) - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + case _: TimestampType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + Try(options.timestampFormat.parse(datum).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(datum).getTime * 1000L + } + case _: DateType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + } + case _: StringType => UTF8String.fromString(datum) + case udt: UserDefinedType[_] => castTo(datum, udt.sqlType, nullable, options) + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 22fb8163b1c0..364d7c831eb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets -import java.text.SimpleDateFormat + +import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} -private[sql] class CSVOptions(@transient private val parameters: Map[String, String]) +private[csv] class CSVOptions(@transient private val parameters: Map[String, String]) extends Logging with Serializable { private def getChar(paramName: String, default: Char): Char = { @@ -101,11 +102,13 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str name.map(CompressionCodecs.getCodecClassName) } - // Share date format object as it is expensive to parse date pattern. - val dateFormat: SimpleDateFormat = { - val dateFormat = parameters.get("dateFormat") - dateFormat.map(new SimpleDateFormat(_)).orNull - } + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. + val dateFormat: FastDateFormat = + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + + val timestampFormat: FastDateFormat = + FastDateFormat.getInstance( + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index 7929ebbd90f7..0a996547d253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging * @param params Parameters object * @param headers headers for the columns */ -private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { +private[csv] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { protected lazy val parser: CsvParser = { val settings = new CsvParserSettings() @@ -60,7 +60,7 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) * @param params Parameters object for configuration * @param headers headers for columns */ -private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { +private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { private val writerSettings = new CsvWriterSettings private val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 083ac3350ef0..d0d5ce06cf8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.types._ @@ -159,7 +160,7 @@ object CSVRelation extends Logging { } } -private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { +private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { override def newInstance( path: String, bucketId: Option[Int], @@ -170,7 +171,7 @@ private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit } } -private[sql] class CsvOutputWriter( +private[csv] class CsvOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext, @@ -179,6 +180,14 @@ private[sql] class CsvOutputWriter( // create the Generator without separator inserted between 2 records private[this] val text = new Text() + // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. + // When the value is null, this converter should not be called. + private type ValueConverter = (InternalRow, Int) => String + + // `ValueConverter`s for all values in the fields of the schema + private val valueConverters: Array[ValueConverter] = + dataSchema.map(_.dataType).map(makeConverter).toArray + private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { @@ -195,18 +204,40 @@ private[sql] class CsvOutputWriter( private var records: Long = 0L private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq) - private def rowToString(row: Seq[Any]): Seq[String] = row.map { field => - if (field != null) { - field.toString - } else { - params.nullValue + private def rowToString(row: InternalRow): Seq[String] = { + var i = 0 + val values = new Array[String](row.numFields) + while (i < row.numFields) { + if (!row.isNullAt(i)) { + values(i) = valueConverters(i).apply(row, i) + } else { + values(i) = params.nullValue + } + i += 1 } + values + } + + private def makeConverter(dataType: DataType): ValueConverter = dataType match { + case DateType => + (row: InternalRow, ordinal: Int) => + params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + + case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + + case dt: DataType => + (row: InternalRow, ordinal: Int) => + row.get(ordinal, dt).toString } override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { - csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), records == 0L && params.headerFlag) + csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag) records += 1 if (records % FLUSH_BATCH_SIZE == 0) { flush() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala index 0b5a19fe9384..ea614e55b540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala @@ -76,7 +76,7 @@ abstract class OutputWriterFactory extends Serializable { * through the [[OutputWriterFactory]] implementation. * @since 2.0.0 */ - private[sql] def newWriter(path: String): OutputWriter = { + def newWriter(path: String): OutputWriter = { throw new UnsupportedOperationException("newInstance with just path not supported") } } @@ -134,13 +134,13 @@ abstract class OutputWriter { * @param options Configuration used when reading / writing data. */ case class HadoopFsRelation( - sparkSession: SparkSession, location: FileCatalog, partitionSchema: StructType, dataSchema: StructType, bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - options: Map[String, String]) extends BaseRelation with FileRelation { + options: Map[String, String])(val sparkSession: SparkSession) + extends BaseRelation with FileRelation { override def sqlContext: SQLContext = sparkSession.sqlContext @@ -249,7 +249,7 @@ trait FileFormat { * appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]] * returns. */ - private[sql] def buildReaderWithPartitionValues( + def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ -356,14 +356,14 @@ trait FileCatalog { /** * Helper methods for gathering metadata from HDFS. */ -private[sql] object HadoopFsRelation extends Logging { +object HadoopFsRelation extends Logging { /** Checks if we should filter out this path name. */ def shouldFilterOut(pathName: String): Boolean = { // We filter everything that starts with _ and ., except _common_metadata and _metadata // because Parquet needs to find those metadata files from leaf files returned by this method. // We should refactor this logic to not mix metadata files with data files. - (pathName.startsWith("_") || pathName.startsWith(".")) && + ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 24e2c1a5fd2f..82bc9d75beff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -38,11 +38,11 @@ import org.apache.spark.unsafe.types.UTF8String /** * Data corresponding to one partition of a JDBCRDD. */ -private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { +case class JDBCPartition(whereClause: String, idx: Int) extends Partition { override def index: Int = idx } -private[sql] object JDBCRDD extends Logging { +object JDBCRDD extends Logging { /** * Maps a JDBC type to a Catalyst type. This function is called only when @@ -136,7 +136,16 @@ private[sql] object JDBCRDD extends Logging { val typeName = rsmd.getColumnTypeName(i + 1) val fieldSize = rsmd.getPrecision(i + 1) val fieldScale = rsmd.getScale(i + 1) - val isSigned = rsmd.isSigned(i + 1) + val isSigned = { + try { + rsmd.isSigned(i + 1) + } catch { + // Workaround for HIVE-14684: + case e: SQLException if + e.getMessage == "Method not supported" && + rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true + } + } val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder() .putString("name", columnName) @@ -192,7 +201,7 @@ private[sql] object JDBCRDD extends Logging { * Turns a single Filter into a String representing a SQL expression. * Returns None for an unhandled filter. */ - private[jdbc] def compileFilter(f: Filter): Option[String] = { + def compileFilter(f: Filter): Option[String] = { Option(f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" case EqualNullSafe(attr, value) => @@ -275,7 +284,7 @@ private[sql] object JDBCRDD extends Logging { * driver code and the workers must be able to access the database; the driver * needs to fetch the schema while the workers need to fetch the data. */ -private[sql] class JDBCRDD( +private[jdbc] class JDBCRDD( sc: SparkContext, getConnection: () => Connection, schema: StructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index d3e1efc56277..7a8b82509383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, DriverManager, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, PreparedStatement, SQLException} import java.util.Properties import scala.collection.JavaConverters._ @@ -54,7 +54,7 @@ object JdbcUtils extends Logging { DriverManager.getDriver(url).getClass.getCanonicalName } () => { - userSpecifiedDriverClass.foreach(DriverRegistry.register) + DriverRegistry.register(driverClass) val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d case d if d.getClass.getCanonicalName == driverClass => d @@ -233,6 +233,17 @@ object JdbcUtils extends Logging { conn.commit() } committed = true + } catch { + case e: SQLException => + val cause = e.getNextException + if (e.getCause != cause) { + if (e.getCause == null) { + e.initCause(cause) + } else { + e.addSuppressed(cause) + } + } + throw e } finally { if (!committed) { // The stage must fail. We got here through an exception path, so diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 579b036417d2..f260304999a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.execution.datasources.json @@ -263,8 +281,8 @@ private[sql] object InferSchema { case (t1: DecimalType, t2: DecimalType) => val scale = math.max(t1.scale, t2.scale) val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) - if (range + scale > 38) { - // DecimalType can't support precision > 38 + if (range + scale > DecimalType.MAX_PRECISION) { + // DecimalType can't support precision > DecimalType.MAX_PRECISION DoubleType } else { DecimalType(range + scale, scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 66f1126fb9ae..02d211d04265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} @@ -53,6 +54,14 @@ private[sql] class JSONOptions( private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord") + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. + val dateFormat: FastDateFormat = + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + + val timestampFormat: FastDateFormat = + FastDateFormat.getInstance( + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 8b920ecafaee..800d43f3039c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -32,11 +32,17 @@ private[sql] object JacksonGenerator { * @param gen a JsonGenerator object * @param row The row to convert */ - def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = { + def apply( + rowSchema: StructType, + gen: JsonGenerator, + options: JSONOptions = new JSONOptions(Map.empty[String, String])) + (row: InternalRow): Unit = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v) => gen.writeString(v.toString) - case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString) + case (TimestampType, v: Long) => + val timestampString = options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(v)) + gen.writeString(timestampString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) case (FloatType, v: Float) => gen.writeNumber(v) @@ -46,7 +52,9 @@ private[sql] object JacksonGenerator { case (ByteType, v: Byte) => gen.writeNumber(v.toInt) case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString) + case (DateType, v: Int) => + val dateString = options.dateFormat.format(DateTimeUtils.toJavaDate(v)) + gen.writeString(dateString) // For UDT values, they should be in the SQL type's corresponding value type. // We should not see values in the user-defined class at here. // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 733fcbfea101..a5417dc4a0e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream import scala.collection.mutable.ArrayBuffer +import scala.util.Try import com.fasterxml.jackson.core._ @@ -56,28 +57,30 @@ object JacksonParser extends Logging { def convertRootField( factory: JsonFactory, parser: JsonParser, - schema: DataType): Any = { + schema: DataType, + configOptions: JSONOptions): Any = { import com.fasterxml.jackson.core.JsonToken._ (parser.getCurrentToken, schema) match { case (START_ARRAY, st: StructType) => // SPARK-3308: support reading top level JSON arrays and take every element // in such an array as a row - convertArray(factory, parser, st) + convertArray(factory, parser, st, configOptions) case (START_OBJECT, ArrayType(st, _)) => // the business end of SPARK-3308: // when an object is found but an array is requested just wrap it in a list - convertField(factory, parser, st) :: Nil + convertField(factory, parser, st, configOptions) :: Nil case _ => - convertField(factory, parser, schema) + convertField(factory, parser, schema, configOptions) } } private def convertField( factory: JsonFactory, parser: JsonParser, - schema: DataType): Any = { + schema: DataType, + configOptions: JSONOptions): Any = { import com.fasterxml.jackson.core.JsonToken._ (parser.getCurrentToken, schema) match { case (null | VALUE_NULL, _) => @@ -85,7 +88,7 @@ object JacksonParser extends Logging { case (FIELD_NAME, _) => parser.nextToken() - convertField(factory, parser, schema) + convertField(factory, parser, schema, configOptions) case (VALUE_STRING, StringType) => UTF8String.fromString(parser.getText) @@ -99,19 +102,29 @@ object JacksonParser extends Logging { case (VALUE_STRING, DateType) => val stringValue = parser.getText - if (stringValue.contains("-")) { - // The format of this string will probably be "yyyy-mm-dd". - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) - } else { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. - stringValue.toInt - } + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(configOptions.dateFormat.parse(parser.getText).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime)) + .getOrElse { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } + } case (VALUE_STRING, TimestampType) => // This one will lose microseconds parts. // See https://issues.apache.org/jira/browse/SPARK-10681. - DateTimeUtils.stringToTime(parser.getText).getTime * 1000L + Try(configOptions.timestampFormat.parse(parser.getText).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(parser.getText).getTime * 1000L + } case (VALUE_NUMBER_INT, TimestampType) => parser.getLongValue * 1000000L @@ -179,16 +192,16 @@ object JacksonParser extends Logging { false case (START_OBJECT, st: StructType) => - convertObject(factory, parser, st) + convertObject(factory, parser, st, configOptions) case (START_ARRAY, ArrayType(st, _)) => - convertArray(factory, parser, st) + convertArray(factory, parser, st, configOptions) case (START_OBJECT, MapType(StringType, kt, _)) => - convertMap(factory, parser, kt) + convertMap(factory, parser, kt, configOptions) case (_, udt: UserDefinedType[_]) => - convertField(factory, parser, udt.sqlType) + convertField(factory, parser, udt.sqlType, configOptions) case (token, dataType) => // We cannot parse this token based on the given data type. So, we throw a @@ -207,12 +220,13 @@ object JacksonParser extends Logging { private def convertObject( factory: JsonFactory, parser: JsonParser, - schema: StructType): InternalRow = { + schema: StructType, + configOptions: JSONOptions): InternalRow = { val row = new GenericMutableRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { case Some(index) => - row.update(index, convertField(factory, parser, schema(index).dataType)) + row.update(index, convertField(factory, parser, schema(index).dataType, configOptions)) case None => parser.skipChildren() @@ -228,12 +242,13 @@ object JacksonParser extends Logging { private def convertMap( factory: JsonFactory, parser: JsonParser, - valueType: DataType): MapData = { + valueType: DataType, + configOptions: JSONOptions): MapData = { val keys = ArrayBuffer.empty[UTF8String] val values = ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_OBJECT)) { keys += UTF8String.fromString(parser.getCurrentName) - values += convertField(factory, parser, valueType) + values += convertField(factory, parser, valueType, configOptions) } ArrayBasedMapData(keys.toArray, values.toArray) } @@ -241,10 +256,11 @@ object JacksonParser extends Logging { private def convertArray( factory: JsonFactory, parser: JsonParser, - elementType: DataType): ArrayData = { + elementType: DataType, + configOptions: JSONOptions): ArrayData = { val values = ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_ARRAY)) { - values += convertField(factory, parser, elementType) + values += convertField(factory, parser, elementType, configOptions) } new GenericArrayData(values.toArray) @@ -285,7 +301,7 @@ object JacksonParser extends Logging { Utils.tryWithResource(factory.createParser(record)) { parser => parser.nextToken() - convertRootField(factory, parser, schema) match { + convertRootField(factory, parser, schema, configOptions) match { case null => failedRecord(record) case row: InternalRow => row :: Nil case array: ArrayData => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 86aef1f7d441..cba3255d86f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -55,7 +56,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) val jsonFiles = files.filterNot { status => val name = status.getPath.getName - name.startsWith("_") || name.startsWith(".") + (name.startsWith("_") && !name.contains("=")) || name.startsWith(".") }.toArray val jsonSchema = InferSchema.infer( @@ -85,7 +86,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, bucketId, dataSchema, context) + new JsonOutputWriter(path, parsedOptions, bucketId, dataSchema, context) } } } @@ -106,7 +107,9 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) (file: PartitionedFile) => { - val lines = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map(_.toString) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val lines = linesReader.map(_.toString) JacksonParser.parseJson( lines, @@ -155,6 +158,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { private[json] class JsonOutputWriter( path: String, + options: JSONOptions, bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) @@ -181,7 +185,7 @@ private[json] class JsonOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { - JacksonGenerator(dataSchema, gen)(row) + JacksonGenerator(dataSchema, gen, options)(row) gen.flush() result.set(writer.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index f1c78bb60c4f..aef0f1b4157f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -37,7 +37,7 @@ import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -46,12 +46,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration -private[sql] class ParquetFileFormat +class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging @@ -233,7 +234,8 @@ private[sql] class ParquetFileFormat // Lists `FileStatus`es of all leaf nodes (files) under all base directories. val leaves = allFiles.filter { f => isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + !((f.getPath.getName.startsWith("_") && !f.getPath.getName.contains("=")) || + f.getPath.getName.startsWith(".")) }.toArray.sortBy(_.getPath.toString) FileTypes( @@ -266,7 +268,7 @@ private[sql] class ParquetFileFormat true } - override private[sql] def buildReaderWithPartitionValues( + override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ -355,6 +357,11 @@ private[sql] class ParquetFileFormat val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId) + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. + if (pushed.isDefined) { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + } val parquetReader = if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader() vectorizedReader.initialize(split, hadoopAttemptContext) @@ -366,6 +373,7 @@ private[sql] class ParquetFileFormat vectorizedReader } else { logDebug(s"Falling back to parquet-mr") + // ParquetRecordReader returns UnsafeRow val reader = pushed match { case Some(filter) => new ParquetRecordReader[InternalRow]( @@ -379,6 +387,7 @@ private[sql] class ParquetFileFormat } val iter = new RecordReaderIterator(parquetReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && @@ -392,8 +401,13 @@ private[sql] class ParquetFileFormat // This is a horrible erasure hack... if we type the iterator above, then it actually check // the type in next() and we get a class cast exception. If we make that function return // Object, then we can defer the cast until later! - iter.asInstanceOf[Iterator[InternalRow]] + if (partitionSchema.length == 0) { + // There is no partition columns + iter.asInstanceOf[Iterator[InternalRow]] + } else { + iter.asInstanceOf[Iterator[InternalRow]] .map(d => appendPartitionColumns(joinedRow(d, file.partitionValues))) + } } } } @@ -416,7 +430,7 @@ private[sql] class ParquetFileFormat * writes the data to the path used to generate the output writer. Callers of this factory * has to ensure which files are to be considered as committed. */ -private[sql] class ParquetOutputWriterFactory( +private[parquet] class ParquetOutputWriterFactory( sqlConf: SQLConf, dataSchema: StructType, hadoopConf: Configuration, @@ -465,7 +479,7 @@ private[sql] class ParquetOutputWriterFactory( * Returns a [[OutputWriter]] that writes data to the give path without using * [[OutputCommitter]]. */ - override private[sql] def newWriter(path: String): OutputWriter = new OutputWriter { + override def newWriter(path: String): OutputWriter = new OutputWriter { // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter private val hadoopTaskAttempId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) @@ -512,7 +526,7 @@ private[sql] class ParquetOutputWriterFactory( // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter( +private[parquet] class ParquetOutputWriter( path: String, bucketId: Option[Int], context: TaskAttemptContext) @@ -550,12 +564,13 @@ private[sql] class ParquetOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } -private[sql] object ParquetFileFormat extends Logging { + +object ParquetFileFormat extends Logging { /** * If parquet's block size (row group size) setting is larger than the min split size, * we use parquet's block size setting as the min split size. Otherwise, we will create @@ -702,7 +717,7 @@ private[sql] object ParquetFileFormat extends Logging { * distinguish binary and string). This method generates a correct schema by merging Metastore * schema data types and Parquet schema field names. */ - private[sql] def mergeMetastoreParquetSchema( + def mergeMetastoreParquetSchema( metastoreSchema: StructType, parquetSchema: StructType): StructType = { def schemaConflictMessage: String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 70ae829219d5..2edd2757428a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -28,7 +28,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.sql.sources import org.apache.spark.sql.types._ -private[sql] object ParquetFilters { +object ParquetFilters { case class SetInFilter[T <: Comparable[T]]( valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index dd2e915e7b7f..3eec582714e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[sql] class ParquetOptions( +private[parquet] class ParquetOptions( @transient private val parameters: Map[String, String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -56,8 +56,8 @@ private[sql] class ParquetOptions( } -private[sql] object ParquetOptions { - private[sql] val MERGE_SCHEMA = "mergeSchema" +object ParquetOptions { + val MERGE_SCHEMA = "mergeSchema" // The parquet compression short names private val shortParquetCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 1ac083f48a8c..8894a7f8dc9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.execution.datasources.parquet @@ -567,7 +585,8 @@ private[parquet] object ParquetSchemaConverter { } // Returns the minimum number of bytes needed to store a decimal with a given `precision`. - val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) + val minBytesForPrecision = Array.tabulate[Int](DecimalType.MAX_PRECISION + 1)( + computeMinBytesForPrecision) // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 15b9d14bd73f..05908d908fd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} /** * Try to replaces [[UnresolvedRelation]]s with [[ResolveDataSource]]. */ -private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { +class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if u.tableIdentifier.database.isDefined => try { @@ -67,7 +67,7 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -private[sql] case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -147,7 +147,7 @@ private[sql] case class PreprocessTableInsertion(conf: SQLConf) extends Rule[Log /** * A rule to do various checks before inserting into or writing to a data source table. */ -private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) +case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) extends (LogicalPlan => Unit) { def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index abb6059f75ba..6aa078af3ea1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -101,6 +102,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) if (requiredSchema.isEmpty) { val emptyUnsafeRow = new UnsafeRow(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e89f792496d6..d321f4cd7687 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.HashSet +import java.util.Collections + +import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -104,21 +106,23 @@ package object debug { } } - private[sql] case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { + case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output - class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] { - private val _set = new HashSet[T]() + class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] { + private val _set = Collections.synchronizedSet(new java.util.HashSet[T]()) override def isZero: Boolean = _set.isEmpty - override def copy(): AccumulatorV2[T, HashSet[T]] = { + override def copy(): AccumulatorV2[T, java.util.Set[T]] = { val newAcc = new SetAccumulator[T]() - newAcc._set ++= _set + newAcc._set.addAll(_set) newAcc } override def reset(): Unit = _set.clear() - override def add(v: T): Unit = _set += v - override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= other.value - override def value: HashSet[T] = _set + override def add(v: T): Unit = _set.add(v) + override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = { + _set.addAll(other.value) + } + override def value: java.util.Set[T] = _set } /** @@ -138,7 +142,9 @@ package object debug { debugPrint(s"== ${child.simpleString} ==") debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => - val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. + val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}") debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index bd0841db7e8a..a809076de541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -38,7 +38,7 @@ case class BroadcastExchangeExec( mode: BroadcastMode, child: SparkPlan) extends Exchange { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"), "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"), "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 446571aa8409..5aabb08efc9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -47,10 +47,11 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { */ private def createPartitioning( requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { + numPartitions: Int, numBuckets: Int = 0): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case ClusteredDistribution(clustering) => + HashPartitioning(clustering, numPartitions, numBuckets) case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) case dist => sys.error(s"Do not know how to satisfy distribution $dist") } @@ -180,10 +181,20 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // partitioned by the same partitioning into the same number of partitions. In that case, // don't try to make them match `defaultPartitions`, just use the existing partitioning. val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max + val numBuckets = { + children.map(child => { + if (child.outputPartitioning.isInstanceOf[OrderlessHashPartitioning]) { + child.outputPartitioning.asInstanceOf[OrderlessHashPartitioning].numBuckets + } + else { + 0 + } + }).reduceLeft(_ max _) + } val useExistingPartitioning = children.zip(requiredChildDistributions).forall { case (child, distribution) => child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + createPartitioning(distribution, maxChildrenNumPartitions, numBuckets)) } children = if (useExistingPartitioning) { @@ -205,10 +216,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // number of partitions. Otherwise, we use maxChildrenNumPartitions. if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions } - children.zip(requiredChildDistributions).map { case (child, distribution) => - val targetPartitioning = createPartitioning(distribution, numPartitions) + val targetPartitioning = createPartitioning(distribution, + numPartitions) if (child.outputPartitioning.guarantees(targetPartitioning)) { child } else { @@ -236,7 +247,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. - if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { + val orderingMatched = if (requiredOrdering.length > child.outputOrdering.length) { + false + } else { + requiredOrdering.zip(child.outputOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + + if (!orderingMatched) { SortExec(requiredOrdering, global = false, child = child) } else { child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 2ea6ee38a932..57da85fa84f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -79,7 +79,7 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * - post-shuffle partition 1: pre-shuffle partition 2 * - post-shuffle partition 2: pre-shuffle partition 3 and 4 */ -private[sql] class ExchangeCoordinator( +class ExchangeCoordinator( numExchanges: Int, advisoryTargetPostShuffleInputSize: Long, minNumPostShufflePartitions: Option[Int] = None) @@ -112,7 +112,7 @@ private[sql] class ExchangeCoordinator( * Estimates partition start indices for post-shuffle partitions based on * mapOutputStatistics provided by all pre-shuffle stages. */ - private[sql] def estimatePartitionStartIndices( + def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { // If we have mapOutputStatistics.length < numExchange, it is because we do not submit // a stage when the number of partitions of this dependency is 0. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index afe0fbea73bd..fd86cdfcac66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -25,7 +25,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -40,7 +40,7 @@ case class ShuffleExchange( child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")) override def nodeName: String = { @@ -81,7 +81,8 @@ case class ShuffleExchange( * the partitioning scheme defined in `newPartitioning`. Those partitions of * the returned ShuffleDependency will be the input of shuffle. */ - private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { + private[exchange] def prepareShuffleDependency() + : ShuffleDependency[Int, InternalRow, InternalRow] = { ShuffleExchange.prepareShuffleDependency( child.execute(), child.output, newPartitioning, serializer) } @@ -92,7 +93,7 @@ case class ShuffleExchange( * partition start indices array. If this optional array is defined, the returned * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. */ - private[sql] def preparePostShuffleRDD( + private[exchange] def preparePostShuffleRDD( shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { // If an array of partition start indices is provided, we need to use this array @@ -194,20 +195,14 @@ object ShuffleExchange { * the partitioning scheme defined in `newPartitioning`. Those partitions of * the returned ShuffleDependency will be the input of shuffle. */ - private[sql] def prepareShuffleDependency( + def prepareShuffleDependency( rdd: RDD[InternalRow], outputAttributes: Seq[Attribute], newPartitioning: Partitioning, serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) - case HashPartitioning(_, n) => - new Partitioner { - override def numPartitions: Int = n - // For HashPartitioning, the partitioning key is already a valid partition ID, as we use - // `HashPartitioning.partitionIdExpression` to produce partitioning key. - override def getPartition(key: Any): Int = key.asInstanceOf[Int] - } + case HashPartitioning(_, n, b) => new HashPartitioner(n, b) case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 7c194ab72643..0f24baacd18d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -45,7 +45,7 @@ case class BroadcastHashJoinExec( right: SparkPlan) extends BinaryExecNode with HashJoin with CodegenSupport { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def requiredChildDistribution: Seq[Distribution] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 4d43765f8fcd..6a9965f1a24c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -37,7 +37,7 @@ case class BroadcastNestedLoopJoinExec( condition: Option[Expression], withinBroadcastThreshold: Boolean = true) extends BinaryExecNode { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) /** BuildRight means the right relation <=> the broadcast relation. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 3a0b6efdfc91..c97fffe88b71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -34,7 +34,6 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter * will be much faster than building the right partition for every row in left RDD, it also * materialize the right RDD (in case of the right RDD is nondeterministic). */ -private[spark] class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { @@ -78,7 +77,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y) CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( - resultIter, sorter.cleanupResources) + resultIter, sorter.cleanupResources()) } } @@ -89,7 +88,7 @@ case class CartesianProductExec( condition: Option[Expression]) extends BinaryExecNode { override def output: Seq[Attribute] = left.output ++ right.output - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doPrepare(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index d46a80423fa3..19cb6a394c2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -134,9 +134,9 @@ trait HashJoin { joinRow.withLeft(srow) val matches = hashedRelation.get(joinKeys(srow)) if (matches != null) { - matches.map(joinRow.withRight(_)).filter(boundCondition) + matches.map(joinRow.withRight).filter(boundCondition) } else { - Seq.empty + Iterator.empty } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 412e8c54ca30..c88d983bb7fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -134,10 +134,17 @@ private[joins] class UnsafeHashedRelation( // re-used in get()/getValue() var resultRow = new UnsafeRow(numFields) + private var mapLoc = initMapLoc() + + private def initMapLoc(): BytesToBytesMap#Location = { + val map = binaryMap + new map.Location + } + override def get(key: InternalRow): Iterator[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] val map = binaryMap // avoid the compiler error - val loc = new map.Location // this could be allocated in stack + val loc = mapLoc binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) if (loc.isDefined) { @@ -243,6 +250,7 @@ private[joins] class UnsafeHashedRelation( taskMemoryManager, (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) + mapLoc = initMapLoc() var i = 0 var keyBuffer = new Array[Byte](1024) @@ -447,10 +455,20 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ private def nextSlot(pos: Int): Int = (pos + 2) & mask + private[this] def toAddress(offset: Long, size: Int): Long = { + ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size + } + + private[this] def toOffset(address: Long): Long = { + (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET + } + + private[this] def toSize(address: Long): Int = { + (address & SIZE_MASK).toInt + } + private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { - val offset = address >>> SIZE_BITS - val size = address & SIZE_MASK - resultRow.pointTo(page, offset, size.toInt) + resultRow.pointTo(page, toOffset(address), toSize(address)) resultRow } @@ -459,9 +477,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >= 0 && key <= maxKey && array(idx) > 0) { - return getRow(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return getRow(value, resultRow) + } } } else { var pos = firstSlot(key) @@ -483,9 +503,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap var addr = address override def hasNext: Boolean = addr != 0 override def next(): UnsafeRow = { - val offset = addr >>> SIZE_BITS - val size = addr & SIZE_MASK - resultRow.pointTo(page, offset, size.toInt) + val offset = toOffset(addr) + val size = toSize(addr) + resultRow.pointTo(page, offset, size) addr = Platform.getLong(page, offset + size) resultRow } @@ -497,9 +517,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >=0 && key <= maxKey && array(idx) > 0) { - return valueIter(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return valueIter(value, resultRow) + } } } else { var pos = firstSlot(key) @@ -550,7 +572,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap Platform.putLong(page, cursor, 0) cursor += 8 numValues += 1 - updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes) + updateIndex(key, toAddress(offset, row.getSizeInBytes)) } /** @@ -558,6 +580,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ private def updateIndex(key: Long, address: Long): Unit = { var pos = firstSlot(key) + assert(numKeys < array.length / 2) while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) } @@ -578,7 +601,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { // there are some values for this key, put the address in the front of them. - val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK) + val pointer = toOffset(address) + toSize(address) Platform.putLong(page, pointer, array(pos + 1)) array(pos + 1) = address } @@ -608,7 +631,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap def optimize(): Unit = { val range = maxKey - minKey // Convert to dense mode if it does not require more memory or could fit within L1 cache - if (range < array.length || range < 1024) { + // SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value + if (range >= 0 && (range < array.length || range < 1024)) { try { ensureAcquireMemory((range + 1) * 8L) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 0036f9aadc5d..afb6e5e3dd23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -39,7 +39,7 @@ case class ShuffledHashJoinExec( right: SparkPlan) extends BinaryExecNode with HashJoin { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index fac6b8de8ed5..5c9c1e6062f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -40,7 +40,7 @@ case class SortMergeJoinExec( left: SparkPlan, right: SparkPlan) extends BinaryExecNode with CodegenSupport { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 781c01609542..86a877071560 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode override def executeCollect(): Array[InternalRow] = child.executeTake(limit) private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) protected override def doExecute(): RDD[InternalRow] = { + val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( ShuffleExchange.prepareShuffleDependency( - child.execute(), child.output, SinglePartition, serializer)) + locallyLimited, child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -114,11 +115,11 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { case class TakeOrderedAndProjectExec( limit: Int, sortOrder: Seq[SortOrder], - projectList: Option[Seq[NamedExpression]], + projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = { - projectList.map(_.map(_.toAttribute)).getOrElse(child.output) + projectList.map(_.toAttribute) } override def outputPartitioning: Partitioning = SinglePartition @@ -126,8 +127,8 @@ case class TakeOrderedAndProjectExec( override def executeCollect(): Array[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - if (projectList.isDefined) { - val proj = UnsafeProjection.create(projectList.get, child.output) + if (projectList != child.output) { + val proj = UnsafeProjection.create(projectList, child.output) data.map(r => proj(r).copy()) } else { data @@ -148,8 +149,8 @@ case class TakeOrderedAndProjectExec( localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) - if (projectList.isDefined) { - val proj = UnsafeProjection.create(projectList.get, child.output) + if (projectList != child.output) { + val proj = UnsafeProjection.create(projectList, child.output) topK.map(r => proj(r)) } else { topK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 9817a56f499a..e6e01a4a7479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat +import java.util.Locale import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} -class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { +final class SQLMetric(val metricType: String, initValue: Long = 0L) + extends AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 @@ -55,17 +57,17 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato override def value: Long = _value // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later - private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { new AccumulableInfo( id, name, update, value, true, true, Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) } } -private[sql] object SQLMetrics { - private[sql] val SUM_METRIC = "sum" - private[sql] val SIZE_METRIC = "size" - private[sql] val TIMING_METRIC = "timing" +object SQLMetrics { + private val SUM_METRIC = "sum" + private val SIZE_METRIC = "size" + private val TIMING_METRIC = "timing" def createMetric(sc: SparkContext, name: String): SQLMetric = { val acc = new SQLMetric(SUM_METRIC) @@ -101,8 +103,7 @@ private[sql] object SQLMetrics { */ def stringValue(metricsType: String, values: Seq[Long]): String = { if (metricsType == SUM_METRIC) { - val numberFormat = NumberFormat.getInstance() - numberFormat.setGroupingUsed(false) + val numberFormat = NumberFormat.getIntegerInstance(Locale.ENGLISH) numberFormat.format(values.sum) } else { val strFormat: Long => String = if (metricsType == SIZE_METRIC) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index cf68ed4ec36a..724025b4647f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -24,9 +24,8 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} @@ -34,16 +33,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String object EvaluatePython { - def takeAndServe(df: DataFrame, n: Int): Int = { - registerPicklers() - df.withNewExecutionId { - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") - } - } def needConversionInPython(dt: DataType): Boolean = dt match { case DateType | TimestampType => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 829bcae6f95d..16e44845d528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.SparkPlan * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or * grouping key, evaluate them after aggregate. */ -private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { +object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { /** * Returns whether the expression could only be evaluated within aggregate. @@ -90,7 +90,7 @@ private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] { +object ExtractPythonUDFs extends Rule[SparkPlan] { private def hasPythonUDF(e: Expression): Boolean = { e.find(_.isInstanceOf[PythonUDF]).isDefined diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala index 70539da348b0..d2178e971ec2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala @@ -21,12 +21,12 @@ import org.apache.spark.api.r._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.Row -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.sql.types.StructType /** * A function wrapper that applies the given R function to each partition. */ -private[sql] case class MapPartitionsRWrapper( +case class MapPartitionsRWrapper( func: Array[Byte], packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index b19344f04383..b9dbfcf7734c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -private[sql] object FrequentItems extends Logging { +object FrequentItems extends Logging { /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ private class FreqItemCounter(size: Int) extends Serializable { @@ -79,7 +79,7 @@ private[sql] object FrequentItems extends Logging { * than 1e-4. * @return A Local DataFrame with the Array of frequent items for each column. */ - private[sql] def singlePassFreqItems( + def singlePassFreqItems( df: DataFrame, cols: Seq[String], support: Double): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index ea58df70b325..7e2ebe856bc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.stat -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} @@ -27,7 +27,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] object StatFunctions extends Logging { +object StatFunctions extends Logging { import QuantileSummaries.Stats @@ -119,7 +119,7 @@ private[sql] object StatFunctions extends Logging { class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, - val sampled: ArrayBuffer[Stats] = ArrayBuffer.empty, + val sampled: Array[Stats] = Array.empty, private[stat] var count: Long = 0L, val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty) extends Serializable { @@ -134,7 +134,12 @@ private[sql] object StatFunctions extends Logging { def insert(x: Double): QuantileSummaries = { headSampled.append(x) if (headSampled.size >= defaultHeadSize) { - this.withHeadBufferInserted + val result = this.withHeadBufferInserted + if (result.sampled.length >= compressThreshold) { + result.compress() + } else { + result + } } else { this } @@ -186,7 +191,7 @@ private[sql] object StatFunctions extends Logging { newSamples.append(sampled(sampleIdx)) sampleIdx += 1 } - new QuantileSummaries(compressThreshold, relativeError, newSamples, currentCount) + new QuantileSummaries(compressThreshold, relativeError, newSamples.toArray, currentCount) } /** @@ -305,10 +310,10 @@ private[sql] object StatFunctions extends Logging { private def compressImmut( currentSamples: IndexedSeq[Stats], - mergeThreshold: Double): ArrayBuffer[Stats] = { - val res: ArrayBuffer[Stats] = ArrayBuffer.empty + mergeThreshold: Double): Array[Stats] = { + val res = ListBuffer.empty[Stats] if (currentSamples.isEmpty) { - return res + return res.toArray } // Start for the last element, which is always part of the set. // The head contains the current new head, that may be merged with the current element. @@ -331,13 +336,16 @@ private[sql] object StatFunctions extends Logging { } res.prepend(head) // If necessary, add the minimum element: - res.prepend(currentSamples.head) - res + val currHead = currentSamples.head + if (currHead.value < head.value) { + res.prepend(currentSamples.head) + } + res.toArray } } /** Calculate the Pearson Correlation Coefficient for the given columns */ - private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { + def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols, "correlation") counts.Ck / math.sqrt(counts.MkX * counts.MkY) } @@ -407,13 +415,13 @@ private[sql] object StatFunctions extends Logging { * @param cols the column names * @return the covariance of the two columns. */ - private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + def calculateCov(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols, "covariance") counts.cov } /** Generate a table of frequencies for the elements of two columns. */ - private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { + def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { val tableName = s"${col1}_$col2" val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt) if (counts.length == 1e6.toInt) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala new file mode 100644 index 000000000000..027b5bbfab8d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.IOException +import java.nio.charset.StandardCharsets.UTF_8 + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.{Path, PathFilter} + +import org.apache.spark.sql.SparkSession + +/** + * An abstract class for compactible metadata logs. It will write one log file for each batch. + * The first line of the log file is the version number, and there are multiple serialized + * metadata lines following. + * + * As reading from many small files is usually pretty slow, also too many + * small files in one folder will mess the FS, [[CompactibleFileStreamLog]] will + * compact log files every 10 batches by default into a big file. When + * doing a compaction, it will read all old log files and merge them with the new batch. + */ +abstract class CompactibleFileStreamLog[T: ClassTag]( + metadataLogVersion: String, + sparkSession: SparkSession, + path: String) + extends HDFSMetadataLog[Array[T]](sparkSession, path) { + + import CompactibleFileStreamLog._ + + /** + * If we delete the old files after compaction at once, there is a race condition in S3: other + * processes may see the old files are deleted but still cannot see the compaction file using + * "list". The `allFiles` handles this by looking for the next compaction file directly, however, + * a live lock may happen if the compaction happens too frequently: one processing keeps deleting + * old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it. + */ + protected def fileCleanupDelayMs: Long + + protected def isDeletingExpiredLog: Boolean + + protected def compactInterval: Int + + /** + * Serialize the data into encoded string. + */ + protected def serializeData(t: T): String + + /** + * Deserialize the string into data object. + */ + protected def deserializeData(encodedString: String): T + + /** + * Filter out the obsolete logs. + */ + def compactLogs(logs: Seq[T]): Seq[T] + + override def batchIdToPath(batchId: Long): Path = { + if (isCompactionBatch(batchId, compactInterval)) { + new Path(metadataPath, s"$batchId$COMPACT_FILE_SUFFIX") + } else { + new Path(metadataPath, batchId.toString) + } + } + + override def pathToBatchId(path: Path): Long = { + getBatchIdFromFileName(path.getName) + } + + override def isBatchFile(path: Path): Boolean = { + try { + getBatchIdFromFileName(path.getName) + true + } catch { + case _: NumberFormatException => false + } + } + + override def serialize(logData: Array[T]): Array[Byte] = { + (metadataLogVersion +: logData.map(serializeData)).mkString("\n").getBytes(UTF_8) + } + + override def deserialize(bytes: Array[Byte]): Array[T] = { + val lines = new String(bytes, UTF_8).split("\n") + if (lines.length == 0) { + throw new IllegalStateException("Incomplete log file") + } + val version = lines(0) + if (version != metadataLogVersion) { + throw new IllegalStateException(s"Unknown log version: ${version}") + } + lines.slice(1, lines.length).map(deserializeData) + } + + override def add(batchId: Long, logs: Array[T]): Boolean = { + if (isCompactionBatch(batchId, compactInterval)) { + compact(batchId, logs) + } else { + super.add(batchId, logs) + } + } + + /** + * Compacts all logs before `batchId` plus the provided `logs`, and writes them into the + * corresponding `batchId` file. It will delete expired files as well if enabled. + */ + private def compact(batchId: Long, logs: Array[T]): Boolean = { + val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) + val allLogs = validBatches.flatMap(batchId => super.get(batchId)).flatten ++ logs + if (super.add(batchId, compactLogs(allLogs).toArray)) { + if (isDeletingExpiredLog) { + deleteExpiredLog(batchId) + } + true + } else { + // Return false as there is another writer. + false + } + } + + /** + * Returns all files except the deleted ones. + */ + def allFiles(): Array[T] = { + var latestId = getLatest().map(_._1).getOrElse(-1L) + // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileCatalog` + // is calling this method. This loop will retry the reading to deal with the + // race condition. + while (true) { + if (latestId >= 0) { + try { + val logs = + getAllValidBatches(latestId, compactInterval).flatMap(id => super.get(id)).flatten + return compactLogs(logs).toArray + } catch { + case e: IOException => + // Another process using `CompactibleFileStreamLog` may delete the batch files when + // `StreamFileCatalog` are reading. However, it only happens when a compaction is + // deleting old files. If so, let's try the next compaction batch and we should find it. + // Otherwise, this is a real IO issue and we should throw it. + latestId = nextCompactionBatchId(latestId, compactInterval) + super.get(latestId).getOrElse { + throw e + } + } + } else { + return Array.empty + } + } + Array.empty + } + + /** + * Since all logs before `compactionBatchId` are compacted and written into the + * `compactionBatchId` log file, they can be removed. However, due to the eventual consistency of + * S3, the compaction file may not be seen by other processes at once. So we only delete files + * created `fileCleanupDelayMs` milliseconds ago. + */ + private def deleteExpiredLog(compactionBatchId: Long): Unit = { + val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs + fileManager.list(metadataPath, new PathFilter { + override def accept(path: Path): Boolean = { + try { + val batchId = getBatchIdFromFileName(path.getName) + batchId < compactionBatchId + } catch { + case _: NumberFormatException => + false + } + } + }).foreach { f => + if (f.getModificationTime <= expiredTime) { + fileManager.delete(f.getPath) + } + } + } +} + +object CompactibleFileStreamLog { + val COMPACT_FILE_SUFFIX = ".compact" + + def getBatchIdFromFileName(fileName: String): Long = { + fileName.stripSuffix(COMPACT_FILE_SUFFIX).toLong + } + + /** + * Returns if this is a compaction batch. FileStreamSinkLog will compact old logs every + * `compactInterval` commits. + * + * E.g., if `compactInterval` is 3, then 2, 5, 8, ... are all compaction batches. + */ + def isCompactionBatch(batchId: Long, compactInterval: Int): Boolean = { + (batchId + 1) % compactInterval == 0 + } + + /** + * Returns all valid batches before the specified `compactionBatchId`. They contain all logs we + * need to do a new compaction. + * + * E.g., if `compactInterval` is 3 and `compactionBatchId` is 5, this method should returns + * `Seq(2, 3, 4)` (Note: it includes the previous compaction batch 2). + */ + def getValidBatchesBeforeCompactionBatch( + compactionBatchId: Long, + compactInterval: Int): Seq[Long] = { + assert(isCompactionBatch(compactionBatchId, compactInterval), + s"$compactionBatchId is not a compaction batch") + (math.max(0, compactionBatchId - compactInterval)) until compactionBatchId + } + + /** + * Returns all necessary logs before `batchId` (inclusive). If `batchId` is a compaction, just + * return itself. Otherwise, it will find the previous compaction batch and return all batches + * between it and `batchId`. + */ + def getAllValidBatches(batchId: Long, compactInterval: Long): Seq[Long] = { + assert(batchId >= 0) + val start = math.max(0, (batchId + 1) / compactInterval * compactInterval - 1) + start to batchId + } + + /** + * Returns the next compaction batch id after `batchId`. + */ + def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = { + (batchId + compactInterval + 1) / compactInterval * compactInterval - 1 + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala new file mode 100644 index 000000000000..3efc20c1d662 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.Try + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap +import org.apache.spark.util.Utils + +/** + * User specified options for file streams. + */ +class FileStreamOptions(parameters: Map[String, String]) extends Logging { + + val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str => + Try(str.toInt).toOption.filter(_ > 0).getOrElse { + throw new IllegalArgumentException( + s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer") + } + } + + /** + * Maximum age of a file that can be found in this directory, before it is deleted. + * + * The max age is specified with respect to the timestamp of the latest file, and not the + * timestamp of the current system. That this means if the last file has timestamp 1000, and the + * current system time is 2000, and max age is 200, the system will purge files older than + * 800 (rather than 1800) from the internal state. + * + * Default to a week. + */ + val maxFileAgeMs: Long = + Utils.timeStringAsMs(parameters.getOrElse("maxFileAge", "7d")) + + /** Options as specified by the user, in a case-insensitive map, without "path" set. */ + val optionMapWithoutPath: Map[String, String] = + new CaseInsensitiveMap(parameters).filterKeys(_ != "path") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 117d6672ee2f..b5f73fb2591d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -56,7 +56,8 @@ class FileStreamSink( private val basePath = new Path(path) private val logPath = new Path(basePath, FileStreamSink.metadataDir) - private val fileLog = new FileStreamSinkLog(sparkSession, logPath.toUri.toString) + private val fileLog = + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() private val fs = basePath.getFileSystem(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index 4254df44c97a..1ed5486f382a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql.execution.streaming -import java.io.IOException -import java.nio.charset.StandardCharsets.UTF_8 - -import org.apache.hadoop.fs.{FileStatus, Path, PathFilter} +import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.json4s.jackson.Serialization.{read, write} @@ -79,213 +76,47 @@ object SinkFileStatus { * When the reader uses `allFiles` to list all files, this method only returns the visible files * (drops the deleted files). */ -class FileStreamSinkLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[Seq[SinkFileStatus]](sparkSession, path) { - - import FileStreamSinkLog._ +class FileStreamSinkLog( + metadataLogVersion: String, + sparkSession: SparkSession, + path: String) + extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) { private implicit val formats = Serialization.formats(NoTypeHints) - /** - * If we delete the old files after compaction at once, there is a race condition in S3: other - * processes may see the old files are deleted but still cannot see the compaction file using - * "list". The `allFiles` handles this by looking for the next compaction file directly, however, - * a live lock may happen if the compaction happens too frequently: one processing keeps deleting - * old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it. - */ - private val fileCleanupDelayMs = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) + protected override val fileCleanupDelayMs = + sparkSession.conf.get(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) + + protected override val isDeletingExpiredLog = + sparkSession.conf.get(SQLConf.FILE_SINK_LOG_DELETION) - private val isDeletingExpiredLog = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_DELETION) + protected override val compactInterval = + sparkSession.conf.get(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) - private val compactInterval = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) require(compactInterval > 0, s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " + "to a positive value.") - override def batchIdToPath(batchId: Long): Path = { - if (isCompactionBatch(batchId, compactInterval)) { - new Path(metadataPath, s"$batchId$COMPACT_FILE_SUFFIX") - } else { - new Path(metadataPath, batchId.toString) - } - } - - override def pathToBatchId(path: Path): Long = { - getBatchIdFromFileName(path.getName) - } - - override def isBatchFile(path: Path): Boolean = { - try { - getBatchIdFromFileName(path.getName) - true - } catch { - case _: NumberFormatException => false - } - } - - override def serialize(logData: Seq[SinkFileStatus]): Array[Byte] = { - (VERSION +: logData.map(write(_))).mkString("\n").getBytes(UTF_8) + protected override def serializeData(data: SinkFileStatus): String = { + write(data) } - override def deserialize(bytes: Array[Byte]): Seq[SinkFileStatus] = { - val lines = new String(bytes, UTF_8).split("\n") - if (lines.length == 0) { - throw new IllegalStateException("Incomplete log file") - } - val version = lines(0) - if (version != VERSION) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } - lines.toSeq.slice(1, lines.length).map(read[SinkFileStatus](_)) + protected override def deserializeData(encodedString: String): SinkFileStatus = { + read[SinkFileStatus](encodedString) } - override def add(batchId: Long, logs: Seq[SinkFileStatus]): Boolean = { - if (isCompactionBatch(batchId, compactInterval)) { - compact(batchId, logs) - } else { - super.add(batchId, logs) - } - } - - /** - * Returns all files except the deleted ones. - */ - def allFiles(): Array[SinkFileStatus] = { - var latestId = getLatest().map(_._1).getOrElse(-1L) - // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileCatalog` - // is calling this method. This loop will retry the reading to deal with the - // race condition. - while (true) { - if (latestId >= 0) { - val startId = getAllValidBatches(latestId, compactInterval)(0) - try { - val logs = get(Some(startId), Some(latestId)).flatMap(_._2) - return compactLogs(logs).toArray - } catch { - case e: IOException => - // Another process using `FileStreamSink` may delete the batch files when - // `StreamFileCatalog` are reading. However, it only happens when a compaction is - // deleting old files. If so, let's try the next compaction batch and we should find it. - // Otherwise, this is a real IO issue and we should throw it. - latestId = nextCompactionBatchId(latestId, compactInterval) - get(latestId).getOrElse { - throw e - } - } - } else { - return Array.empty - } - } - Array.empty - } - - /** - * Compacts all logs before `batchId` plus the provided `logs`, and writes them into the - * corresponding `batchId` file. It will delete expired files as well if enabled. - */ - private def compact(batchId: Long, logs: Seq[SinkFileStatus]): Boolean = { - val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) - val allLogs = validBatches.flatMap(batchId => get(batchId)).flatten ++ logs - if (super.add(batchId, compactLogs(allLogs))) { - if (isDeletingExpiredLog) { - deleteExpiredLog(batchId) - } - true + override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = { + val deletedFiles = logs.filter(_.action == FileStreamSinkLog.DELETE_ACTION).map(_.path).toSet + if (deletedFiles.isEmpty) { + logs } else { - // Return false as there is another writer. - false - } - } - - /** - * Since all logs before `compactionBatchId` are compacted and written into the - * `compactionBatchId` log file, they can be removed. However, due to the eventual consistency of - * S3, the compaction file may not be seen by other processes at once. So we only delete files - * created `fileCleanupDelayMs` milliseconds ago. - */ - private def deleteExpiredLog(compactionBatchId: Long): Unit = { - val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs - fileManager.list(metadataPath, new PathFilter { - override def accept(path: Path): Boolean = { - try { - val batchId = getBatchIdFromFileName(path.getName) - batchId < compactionBatchId - } catch { - case _: NumberFormatException => - false - } - } - }).foreach { f => - if (f.getModificationTime <= expiredTime) { - fileManager.delete(f.getPath) - } + logs.filter(f => !deletedFiles.contains(f.path)) } } } object FileStreamSinkLog { val VERSION = "v1" - val COMPACT_FILE_SUFFIX = ".compact" val DELETE_ACTION = "delete" val ADD_ACTION = "add" - - def getBatchIdFromFileName(fileName: String): Long = { - fileName.stripSuffix(COMPACT_FILE_SUFFIX).toLong - } - - /** - * Returns if this is a compaction batch. FileStreamSinkLog will compact old logs every - * `compactInterval` commits. - * - * E.g., if `compactInterval` is 3, then 2, 5, 8, ... are all compaction batches. - */ - def isCompactionBatch(batchId: Long, compactInterval: Int): Boolean = { - (batchId + 1) % compactInterval == 0 - } - - /** - * Returns all valid batches before the specified `compactionBatchId`. They contain all logs we - * need to do a new compaction. - * - * E.g., if `compactInterval` is 3 and `compactionBatchId` is 5, this method should returns - * `Seq(2, 3, 4)` (Note: it includes the previous compaction batch 2). - */ - def getValidBatchesBeforeCompactionBatch( - compactionBatchId: Long, - compactInterval: Int): Seq[Long] = { - assert(isCompactionBatch(compactionBatchId, compactInterval), - s"$compactionBatchId is not a compaction batch") - (math.max(0, compactionBatchId - compactInterval)) until compactionBatchId - } - - /** - * Returns all necessary logs before `batchId` (inclusive). If `batchId` is a compaction, just - * return itself. Otherwise, it will find the previous compaction batch and return all batches - * between it and `batchId`. - */ - def getAllValidBatches(batchId: Long, compactInterval: Long): Seq[Long] = { - assert(batchId >= 0) - val start = math.max(0, (batchId + 1) / compactInterval * compactInterval - 1) - start to batchId - } - - /** - * Removes all deleted files from logs. It assumes once one file is deleted, it won't be added to - * the log in future. - */ - def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = { - val deletedFiles = logs.filter(_.action == DELETE_ACTION).map(_.path).toSet - if (deletedFiles.isEmpty) { - logs - } else { - logs.filter(f => !deletedFiles.contains(f.path)) - } - } - - /** - * Returns the next compaction batch id after `batchId`. - */ - def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = { - (batchId + compactInterval + 1) / compactInterval * compactInterval - 1 - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0cfad659dc92..8c3e7184a65b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -17,21 +17,18 @@ package org.apache.spark.sql.execution.streaming -import scala.util.Try +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, DataSource, ListingFileCatalog, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{DataSource, ListingFileCatalog, LogicalRelation} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.collection.OpenHashSet /** - * A very simple source that reads text files from the given directory as they appear. - * - * TODO Clean up the metadata files periodically + * A very simple source that reads files from the given directory as they appear. */ class FileStreamSource( sparkSession: SparkSession, @@ -41,18 +38,32 @@ class FileStreamSource( metadataPath: String, options: Map[String, String]) extends Source with Logging { - private val fs = new Path(path).getFileSystem(sparkSession.sessionState.newHadoopConf()) - private val qualifiedBasePath = fs.makeQualified(new Path(path)) // can contains glob patterns - private val metadataLog = new HDFSMetadataLog[Seq[String]](sparkSession, metadataPath) + import FileStreamSource._ + + private val sourceOptions = new FileStreamOptions(options) + + private val qualifiedBasePath: Path = { + val fs = new Path(path).getFileSystem(sparkSession.sessionState.newHadoopConf()) + fs.makeQualified(new Path(path)) // can contains glob patterns + } + + private val metadataLog = + new FileStreamSourceLog(FileStreamSourceLog.VERSION, sparkSession, metadataPath) private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) /** Maximum number of new files to be considered in each batch */ - private val maxFilesPerBatch = getMaxFilesPerBatch() + private val maxFilesPerBatch = sourceOptions.maxFilesPerTrigger + + /** A mapping from a file that we have processed to some timestamp it was last modified. */ + // Visible for testing and debugging in production. + val seenFiles = new SeenFilesMap(sourceOptions.maxFileAgeMs) - private val seenFiles = new OpenHashSet[String] - metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) => - files.foreach(seenFiles.add) + metadataLog.allFiles().foreach { entry => + seenFiles.add(entry.path, entry.timestamp) } + seenFiles.purge() + + logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = ${sourceOptions.maxFileAgeMs}") /** * Returns the maximum offset that can be retrieved from the source. @@ -61,19 +72,34 @@ class FileStreamSource( * there is no race here, so the cost of `synchronized` should be rare. */ private def fetchMaxOffset(): LongOffset = synchronized { - val newFiles = fetchAllFiles().filter(!seenFiles.contains(_)) + // All the new files found - ignore aged files and files that we have seen. + val newFiles = fetchAllFiles().filter { + case (path, timestamp) => seenFiles.isNewFile(path, timestamp) + } + + // Obey user's setting to limit the number of files in this batch trigger. val batchFiles = if (maxFilesPerBatch.nonEmpty) newFiles.take(maxFilesPerBatch.get) else newFiles + batchFiles.foreach { file => - seenFiles.add(file) + seenFiles.add(file._1, file._2) logDebug(s"New file: $file") } - logTrace(s"Number of new files = ${newFiles.size})") - logTrace(s"Number of files selected for batch = ${batchFiles.size}") - logTrace(s"Number of seen files = ${seenFiles.size}") + val numPurged = seenFiles.purge() + + logTrace( + s""" + |Number of new files = ${newFiles.size} + |Number of files selected for batch = ${batchFiles.size} + |Number of seen files = ${seenFiles.size} + |Number of files purged from tracking map = $numPurged + """.stripMargin) + if (batchFiles.nonEmpty) { maxBatchId += 1 - metadataLog.add(maxBatchId, batchFiles) + metadataLog.add(maxBatchId, batchFiles.map { case (path, timestamp) => + FileEntry(path = path, timestamp = timestamp, batchId = maxBatchId) + }.toArray) logInfo(s"Max batch id increased to $maxBatchId with ${batchFiles.size} new files") } @@ -104,22 +130,27 @@ class FileStreamSource( val files = metadataLog.get(Some(startId + 1), Some(endId)).flatMap(_._2) logInfo(s"Processing ${files.length} files from ${startId + 1}:$endId") logTrace(s"Files are:\n\t" + files.mkString("\n\t")) - val newOptions = new CaseInsensitiveMap(options).filterKeys(_ != "path") val newDataSource = DataSource( sparkSession, - paths = files, + paths = files.map(_.path), userSpecifiedSchema = Some(schema), className = fileFormatClassName, - options = newOptions) - Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation())) + options = sourceOptions.optionMapWithoutPath) + Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( + checkPathExist = false))) } - private def fetchAllFiles(): Seq[String] = { + /** + * Returns a list of files found, sorted by their timestamp. + */ + private def fetchAllFiles(): Seq[(String, Long)] = { val startTime = System.nanoTime val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) val catalog = new ListingFileCatalog(sparkSession, globbedPaths, options, Some(new StructType)) - val files = catalog.allFiles().sortBy(_.getModificationTime).map(_.getPath.toUri.toString) + val files = catalog.allFiles().sortBy(_.getModificationTime).map { status => + (status.getPath.toUri.toString, status.getModificationTime) + } val endTime = System.nanoTime val listingTimeMs = (endTime.toDouble - startTime) / 1000000 if (listingTimeMs > 2000) { @@ -132,20 +163,76 @@ class FileStreamSource( files } - private def getMaxFilesPerBatch(): Option[Int] = { - new CaseInsensitiveMap(options) - .get("maxFilesPerTrigger") - .map { str => - Try(str.toInt).toOption.filter(_ > 0).getOrElse { - throw new IllegalArgumentException( - s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer") - } - } - } - override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1) override def toString: String = s"FileStreamSource[$qualifiedBasePath]" override def stop() {} } + + +object FileStreamSource { + + /** Timestamp for file modification time, in ms since January 1, 1970 UTC. */ + type Timestamp = Long + + case class FileEntry(path: String, timestamp: Timestamp, batchId: Long) extends Serializable + + /** + * A custom hash map used to track the list of files seen. This map is not thread-safe. + * + * To prevent the hash map from growing indefinitely, a purge function is available to + * remove files "maxAgeMs" older than the latest file. + */ + class SeenFilesMap(maxAgeMs: Long) { + require(maxAgeMs >= 0) + + /** Mapping from file to its timestamp. */ + private val map = new java.util.HashMap[String, Timestamp] + + /** Timestamp of the latest file. */ + private var latestTimestamp: Timestamp = 0L + + /** Timestamp for the last purge operation. */ + private var lastPurgeTimestamp: Timestamp = 0L + + /** Add a new file to the map. */ + def add(path: String, timestamp: Timestamp): Unit = { + map.put(path, timestamp) + if (timestamp > latestTimestamp) { + latestTimestamp = timestamp + } + } + + /** + * Returns true if we should consider this file a new file. The file is only considered "new" + * if it is new enough that we are still tracking, and we have not seen it before. + */ + def isNewFile(path: String, timestamp: Timestamp): Boolean = { + // Note that we are testing against lastPurgeTimestamp here so we'd never miss a file that + // is older than (latestTimestamp - maxAgeMs) but has not been purged yet. + timestamp >= lastPurgeTimestamp && !map.containsKey(path) + } + + /** Removes aged entries and returns the number of files removed. */ + def purge(): Int = { + lastPurgeTimestamp = latestTimestamp - maxAgeMs + val iter = map.entrySet().iterator() + var count = 0 + while (iter.hasNext) { + val entry = iter.next() + if (entry.getValue < lastPurgeTimestamp) { + count += 1 + iter.remove() + } + } + count + } + + def size: Int = map.size() + + def allEntries: Seq[(String, Timestamp)] = { + map.asScala.toSeq + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala new file mode 100644 index 000000000000..8103309aff2a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.{LinkedHashMap => JLinkedHashMap} +import java.util.Map.Entry + +import scala.collection.mutable + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry +import org.apache.spark.sql.internal.SQLConf + +class FileStreamSourceLog( + metadataLogVersion: String, + sparkSession: SparkSession, + path: String) + extends CompactibleFileStreamLog[FileEntry](metadataLogVersion, sparkSession, path) { + + import CompactibleFileStreamLog._ + + // Configurations about metadata compaction + protected override val compactInterval = + sparkSession.conf.get(SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL) + require(compactInterval > 0, + s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} (was $compactInterval) to a " + + s"positive value.") + + protected override val fileCleanupDelayMs = + sparkSession.conf.get(SQLConf.FILE_SOURCE_LOG_CLEANUP_DELAY) + + protected override val isDeletingExpiredLog = + sparkSession.conf.get(SQLConf.FILE_SOURCE_LOG_DELETION) + + private implicit val formats = Serialization.formats(NoTypeHints) + + // A fixed size log entry cache to cache the file entries belong to the compaction batch. It is + // used to avoid scanning the compacted log file to retrieve it's own batch data. + private val cacheSize = compactInterval + private val fileEntryCache = new JLinkedHashMap[Long, Array[FileEntry]] { + override def removeEldestEntry(eldest: Entry[Long, Array[FileEntry]]): Boolean = { + size() > cacheSize + } + } + + protected override def serializeData(data: FileEntry): String = { + Serialization.write(data) + } + + protected override def deserializeData(encodedString: String): FileEntry = { + Serialization.read[FileEntry](encodedString) + } + + def compactLogs(logs: Seq[FileEntry]): Seq[FileEntry] = { + logs + } + + override def add(batchId: Long, logs: Array[FileEntry]): Boolean = { + if (super.add(batchId, logs)) { + if (isCompactionBatch(batchId, compactInterval)) { + fileEntryCache.put(batchId, logs) + } + true + } else { + false + } + } + + override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, Array[FileEntry])] = { + val startBatchId = startId.getOrElse(0L) + val endBatchId = getLatest().map(_._1).getOrElse(0L) + + val (existedBatches, removedBatches) = (startBatchId to endBatchId).map { id => + if (isCompactionBatch(id, compactInterval) && fileEntryCache.containsKey(id)) { + (id, Some(fileEntryCache.get(id))) + } else { + val logs = super.get(id).map(_.filter(_.batchId == id)) + (id, logs) + } + }.partition(_._2.isDefined) + + // The below code may only be happened when original metadata log file has been removed, so we + // have to get the batch from latest compacted log file. This is quite time-consuming and may + // not be happened in the current FileStreamSource code path, since we only fetch the + // latest metadata log file. + val searchKeys = removedBatches.map(_._1) + val retrievedBatches = if (searchKeys.nonEmpty) { + logWarning(s"Get batches from removed files, this is unexpected in the current code path!!!") + val latestBatchId = getLatest().map(_._1).getOrElse(-1L) + if (latestBatchId < 0) { + Map.empty[Long, Option[Array[FileEntry]]] + } else { + val latestCompactedBatchId = getAllValidBatches(latestBatchId, compactInterval)(0) + val allLogs = new mutable.HashMap[Long, mutable.ArrayBuffer[FileEntry]] + + super.get(latestCompactedBatchId).foreach { entries => + entries.foreach { e => + allLogs.put(e.batchId, allLogs.getOrElse(e.batchId, mutable.ArrayBuffer()) += e) + } + } + + searchKeys.map(id => id -> allLogs.get(id).map(_.toArray)).filter(_._2.isDefined).toMap + } + } else { + Map.empty[Long, Option[Array[FileEntry]]] + } + + (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1) + } +} + +object FileStreamSourceLog { + val VERSION = "v1" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 698f07b0a187..39a0f3341389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -49,6 +49,10 @@ import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) extends MetadataLog[T] with Logging { + // Avoid serializing generic sequences, see SPARK-17372 + require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], + "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") + import HDFSMetadataLog._ val metadataPath = new Path(path) @@ -180,7 +184,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) private def isFileAlreadyExistsException(e: IOException): Boolean = { e.isInstanceOf[FileAlreadyExistsException] || // Old Hadoop versions don't throw FileAlreadyExistsException. Although it's fixed in - // HADOOP-9361, we still need to support old Hadoop versions. + // HADOOP-9361 in Hadoop 2.5, we still need to support old Hadoop versions. (e.getMessage != null && e.getMessage.startsWith("File already exists: ")) } @@ -227,6 +231,20 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) None } + /** + * Removes all the log entry earlier than thresholdBatchId (exclusive). + */ + override def purge(thresholdBatchId: Long): Unit = { + val batchIds = fileManager.list(metadataPath, batchFilesFilter) + .map(f => pathToBatchId(f.getPath)) + + for (batchId <- batchIds if batchId < thresholdBatchId) { + val path = batchIdToPath(batchId) + fileManager.delete(path) + logTrace(s"Removed metadata log file: $path") + } + } + private def createFileManager(): FileManager = { val hadoopConf = sparkSession.sessionState.newHadoopConf() try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 7367c68d0a0e..05294df2673d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.streaming.OutputMode * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] * plan incrementally. Possibly preserving state in between each execution. */ -class IncrementalExecution private[sql]( +class IncrementalExecution( sparkSession: SparkSession, logicalPlan: LogicalPlan, val outputMode: OutputMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala index cc70e1d314d1..9e2604c9c069 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala @@ -24,6 +24,7 @@ package org.apache.spark.sql.execution.streaming * - Allow the user to query the latest batch id. * - Allow the user to query the metadata object of a specified batch id. * - Allow the user to query metadata objects in a range of batch ids. + * - Allow the user to remove obsolete metadata */ trait MetadataLog[T] { @@ -48,4 +49,10 @@ trait MetadataLog[T] { * Return the latest batch Id and its metadata if exist. */ def getLatest(): Option[(Long, T)] + + /** + * Removes all the log entry earlier than thresholdBatchId (exclusive). + * This operation should be idempotent. + */ + def purge(thresholdBatchId: Long): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala index 20ade12e3796..a32c4671e347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala @@ -34,7 +34,8 @@ class MetadataLogFileCatalog(sparkSession: SparkSession, path: Path) private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") - private val metadataLog = new FileStreamSinkLog(sparkSession, metadataDirectory.toUri.toString) + private val metadataLog = + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toUri.toString) private val allFilesFromLog = metadataLog.allFiles().map(_.toFileStatus).filterNot(_.isDirectory) private var cachedPartitionSpec: PartitionSpec = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index af2229a46beb..b7587f26af9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -49,10 +49,10 @@ class StreamExecution( override val id: Long, override val name: String, checkpointRoot: String, - private[sql] val logicalPlan: LogicalPlan, + val logicalPlan: LogicalPlan, val sink: Sink, val trigger: Trigger, - private[sql] val triggerClock: Clock, + val triggerClock: Clock, val outputMode: OutputMode) extends StreamingQuery with Logging { @@ -74,7 +74,7 @@ class StreamExecution( * input source. */ @volatile - private[sql] var committedOffsets = new StreamProgress + var committedOffsets = new StreamProgress /** * Tracks the offsets that are available to be processed, but have not yet be committed to the @@ -102,10 +102,10 @@ class StreamExecution( private var state: State = INITIALIZED @volatile - private[sql] var lastExecution: QueryExecution = null + var lastExecution: QueryExecution = null @volatile - private[sql] var streamDeathCause: StreamingQueryException = null + var streamDeathCause: StreamingQueryException = null /* Get the call site in the caller thread; will pass this into the micro batch thread */ private val callSite = Utils.getCallSite() @@ -115,7 +115,7 @@ class StreamExecution( * [[org.apache.spark.util.UninterruptibleThread]] to avoid potential deadlocks in using * [[HDFSMetadataLog]]. See SPARK-14131 for more details. */ - private[sql] val microBatchThread = + val microBatchThread = new UninterruptibleThread(s"stream execution thread for $name") { override def run(): Unit = { // To fix call site like "run at :0", we bridge the call site from the caller @@ -131,8 +131,7 @@ class StreamExecution( * processing is done. Thus, the Nth record in this log indicated data that is currently being * processed and the N-1th entry indicates which offsets have been durably committed to the sink. */ - private[sql] val offsetLog = - new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets")) + val offsetLog = new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets")) /** Whether the query is currently active or not */ override def isActive: Boolean = state == ACTIVE @@ -159,7 +158,7 @@ class StreamExecution( * Starts the execution. This returns only after the thread has started and [[QueryStarted]] event * has been posted to all the listeners. */ - private[sql] def start(): Unit = { + def start(): Unit = { microBatchThread.setDaemon(true) microBatchThread.start() startLatch.await() // Wait until thread started and QueryStart event has been posted @@ -218,10 +217,7 @@ class StreamExecution( } finally { state = TERMINATED sparkSession.streams.notifyQueryTermination(StreamExecution.this) - postEvent(new QueryTerminated( - this.toInfo, - exception.map(_.getMessage), - exception.map(_.getStackTrace.toSeq).getOrElse(Nil))) + postEvent(new QueryTerminated(this.toInfo, exception.map(_.cause).map(Utils.exceptionString))) terminationLatch.countDown() } } @@ -294,6 +290,13 @@ class StreamExecution( assert(offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId.") + + // Now that we have logged the new batch, no further processing will happen for + // the previous batch, and it is safe to discard the old metadata. + // Note that purge is exclusive, i.e. it purges everything before currentBatchId. + // NOTE: If StreamExecution implements pipeline parallelism (multiple batches in + // flight at the same time), this cleanup logic will need to change. + offsetLog.purge(currentBatchId) } else { awaitBatchLock.lock() try { @@ -411,6 +414,9 @@ class StreamExecution( awaitBatchLock.lock() try { awaitBatchLockCondition.await(100, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } } finally { awaitBatchLock.unlock() } @@ -518,7 +524,7 @@ class StreamExecution( case object TERMINATED extends State } -private[sql] object StreamExecution { +object StreamExecution { private val _nextId = new AtomicLong(0) def nextId: Long = _nextId.getAndIncrement() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 405a5f0387a7..db0bd9e6bc6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -26,7 +26,7 @@ class StreamProgress( val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset]) extends scala.collection.immutable.Map[Source, Offset] { - private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = { + def toCompositeOffset(source: Seq[Source]): CompositeOffset = { CompositeOffset(source.map(get)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 066765324ac9..a67fdceb3cee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -113,7 +113,7 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate * the store is the active instance. Accordingly, it either keeps it loaded and performs * maintenance, or unloads the store. */ -private[sql] object StateStore extends Logging { +object StateStore extends Logging { val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval" val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index e418217238cc..d945d7aff2da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -45,7 +45,7 @@ private object StopCoordinator extends StateStoreCoordinatorMessage /** Helper object used to create reference to [[StateStoreCoordinator]]. */ -private[sql] object StateStoreCoordinatorRef extends Logging { +object StateStoreCoordinatorRef extends Logging { private val endpointName = "StateStoreCoordinator" @@ -77,7 +77,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging { * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of * [[StateStore]]s across all the executors, and get their locations for job scheduling. */ -private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { +class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( storeId: StateStoreId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 4b4fa126b85f..23fc0bd0bce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} -private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { +class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { private val listener = parent.listener diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 6e9479190176..60f13432d78d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -46,14 +46,14 @@ case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) case class SparkListenerDriverAccumUpdates(executionId: Long, accumUpdates: Seq[(Long, Long)]) extends SparkListenerEvent -private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { +class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { List(new SQLHistoryListener(conf, sparkUI)) } } -private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { +class SQLListener(conf: SparkConf) extends SparkListener with Logging { private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000) @@ -333,7 +333,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi /** * A [[SQLListener]] for rendering the SQL UI in the history server. */ -private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) +class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) extends SQLListener(conf) { private var sqlTabAttached = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index e8675ce749a2..d0376af3e31c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.ui import org.apache.spark.internal.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} -private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) +class SQLTab(val listener: SQLListener, sparkUI: SparkUI) extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -32,6 +32,6 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) parent.addStaticHandler(SQLTab.STATIC_RESOURCE_DIR, "/static/sql") } -private[sql] object SQLTab { +object SQLTab { private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 8f5681bfc7cc..4bb9d6fef4c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.execution.{SparkPlanInfo, WholeStageCodegenExec} -import org.apache.spark.sql.execution.metric.SQLMetrics + /** * A graph used for storing information of an executionPlan of DataFrame. @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the * SparkPlan tree, and each edge represents a parent-child relationship between two nodes. */ -private[ui] case class SparkPlanGraph( +case class SparkPlanGraph( nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) { def makeDotFile(metrics: Map[Long, String]): String = { @@ -55,7 +55,7 @@ private[ui] case class SparkPlanGraph( } } -private[sql] object SparkPlanGraph { +object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala new file mode 100644 index 000000000000..174378304d4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + +/** + * An aggregator that uses a single associative and commutative reduce function. This reduce + * function can be used to go through all input values and reduces them to a single value. + * If there is no input, a null value is returned. + * + * This class currently assumes there is at least one input row. + */ +private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) + extends Aggregator[T, (Boolean, T), T] { + + private val encoder = implicitly[Encoder[T]] + + override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) + + override def bufferEncoder: Encoder[(Boolean, T)] = + ExpressionEncoder.tuple( + ExpressionEncoder[Boolean](), + encoder.asInstanceOf[ExpressionEncoder[T]]) + + override def outputEncoder: Encoder[T] = encoder + + override def reduce(b: (Boolean, T), a: T): (Boolean, T) = { + if (b._1) { + (true, func(b._2, a)) + } else { + (true, a) + } + } + + override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = { + if (!b1._1) { + b2 + } else if (!b2._1) { + b1 + } else { + (true, func(b1._2, b2._2)) + } + } + + override def finish(reduction: (Boolean, T)): T = { + if (!reduction._1) { + throw new IllegalStateException("ReduceAggregator requires at least one input row") + } + reduction._2 + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ab09ef7450b0..eb504c81bd80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2175,7 +2175,8 @@ object functions { def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } /** - * Extract a specific(idx) group identified by a java regex, from the specified string column. + * Extract a specific group matched by a Java regex, from the specified string column. + * If the regex did not match, or the specified group did not match, an empty string is returned. * * @group string_funcs * @since 1.5.0 @@ -2595,12 +2596,15 @@ object functions { * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for - * valid duration identifiers. + * valid duration identifiers. Note that the duration is a fixed length of + * time, and does not vary over time according to a calendar. For example, + * `1 day` always means 86,400,000 milliseconds, not a calendar day. * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. * A new window will be generated every `slideDuration`. Must be less than * or equal to the `windowDuration`. Check * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration - * identifiers. + * identifiers. This duration is likewise absolute, and does not vary + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2649,11 +2653,15 @@ object functions { * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for - * valid duration identifiers. + * valid duration identifiers. Note that the duration is a fixed length of + * time, and does not vary over time according to a calendar. For example, + * `1 day` always means 86,400,000 milliseconds, not a calendar day. * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. * A new window will be generated every `slideDuration`. Must be less than * or equal to the `windowDuration`. Check - * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration. + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration + * identifiers. This duration is likewise absolute, and does not vary + * according to a calendar. * * @group datetime_funcs * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index a6ae6fe2aad2..414a4a5ed910 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -151,7 +151,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } private def listColumns(tableIdentifier: TableIdentifier): Dataset[Column] = { - val tableMetadata = sessionCatalog.getTableMetadata(tableIdentifier) + val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdentifier) val partitionColumnNames = tableMetadata.partitionColumnNames.toSet val bucketColumnNames = tableMetadata.bucketColumnNames.toSet val columns = tableMetadata.schema.map { c => @@ -296,8 +296,10 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def dropTempView(viewName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(viewName)) - sessionCatalog.dropTable(TableIdentifier(viewName), ignoreIfNotExists = true) + sparkSession.sessionState.catalog.getTempView(viewName).foreach { tempView => + sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) + sessionCatalog.dropTempView(viewName) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1a9bb6a0b54e..2614032d04a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.immutable +import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging @@ -318,6 +319,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val GATHER_FASTSTAT = SQLConfigBuilder("spark.sql.hive.gatherFastStats") + .internal() + .doc("When true, fast stats (number of files and total size of all files) will be gathered" + + " in parallel while repairing table partitions to avoid the sequential listing in Hive" + + " metastore.") + .booleanConf + .createWithDefault(true) + // This is used to control the when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has // a length restriction of 4000 characters). We will split the JSON string of a schema @@ -536,7 +545,28 @@ object SQLConf { .internal() .doc("How long that a file is guaranteed to be visible for all readers.") .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(60 * 1000L) // 10 minutes + .createWithDefault(TimeUnit.MINUTES.toMillis(10)) // 10 minutes + + val FILE_SOURCE_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSource.log.deletion") + .internal() + .doc("Whether to delete the expired log files in file stream source.") + .booleanConf + .createWithDefault(true) + + val FILE_SOURCE_LOG_COMPACT_INTERVAL = + SQLConfigBuilder("spark.sql.streaming.fileSource.log.compactInterval") + .internal() + .doc("Number of log files after which all the previous files " + + "are compacted into the next log file.") + .intConf + .createWithDefault(10) + + val FILE_SOURCE_LOG_CLEANUP_DELAY = + SQLConfigBuilder("spark.sql.streaming.fileSource.log.cleanupDelay") + .internal() + .doc("How long in milliseconds a file is guaranteed to be visible for all readers.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(10)) // 10 minutes val STREAMING_SCHEMA_INFERENCE = SQLConfigBuilder("spark.sql.streaming.schemaInference") @@ -615,6 +645,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def nativeView: Boolean = getConf(NATIVE_VIEW) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) + def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) @@ -691,9 +723,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH) def warehousePath: String = { - getConf(WAREHOUSE_PATH).replace("${system:user.dir}", System.getProperty("user.dir")) + new Path(getConf(WAREHOUSE_PATH).replace("${system:user.dir}", + System.getProperty("user.dir"))).toString } - override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 6c43fe3177d6..54aee5e02bb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.internal -import org.apache.hadoop.conf.Configuration - import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index d2077a07f440..b84953deac9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -112,8 +112,10 @@ trait SchemaRelationProvider { } /** + * ::Experimental:: * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. */ +@Experimental trait StreamSourceProvider { /** Returns the name and schema of the source that can be used to continually read data. */ @@ -132,8 +134,10 @@ trait StreamSourceProvider { } /** + * ::Experimental:: * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. */ +@Experimental trait StreamSinkProvider { def createSink( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2e606b21bdf3..36c80ad8362e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -161,6 +161,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * schema in advance, use the version that specifies the schema to avoid the extra scan. * * You can set the following JSON-specific options to deal with non-standard JSON files: + *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • *
    • `primitivesAsString` (default `false`): infers all primitive values as a string type
    • @@ -175,17 +176,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
    • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all * character using backslash quoting mechanism
    • *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
    • - *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the - * malformed string into a new field configured by `columnNameOfCorruptRecord`. When - * a schema is set by user, it sets `null` for extra fields.
      • - *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • - *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • - *
      + * during parsing. + *
        + *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts + * the malformed string into a new field configured by `columnNameOfCorruptRecord`. When + * a schema is set by user, it sets `null` for extra fields.
      • + *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • + *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • + *
      + * *
    • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
    • + *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. + * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to + * date type.
    • + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + * indicates a timestamp format. Custom date formats follow the formats at + * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • + *
    * * @since 2.0.0 */ @@ -201,6 +210,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * specify the schema explicitly using [[schema]]. * * You can set the following CSV-specific options to deal with CSV files: + *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • *
    • `sep` (default `,`): sets the single character as a separator for each @@ -222,27 +232,32 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * from values being read should be skipped.
    • *
    • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing * whitespaces from values being read should be skipped.
    • - *
    • `nullValue` (default empty string): sets the string representation of a null value.
    • + *
    • `nullValue` (default empty string): sets the string representation of a null value. Since + * 2.0.1, this applies to all supported types including the string type.
    • *
    • `nanValue` (default `NaN`): sets the string representation of a non-number" value.
    • *
    • `positiveInf` (default `Inf`): sets the string representation of a positive infinity * value.
    • *
    • `negativeInf` (default `-Inf`): sets the string representation of a negative infinity * value.
    • - *
    • `dateFormat` (default `null`): sets the string that indicates a date format. Custom date - * formats follow the formats at `java.text.SimpleDateFormat`. This applies to both date type - * and timestamp type. By default, it is `null` which means trying to parse times and date by - * `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()`.
    • + *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. + * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to + * date type.
    • + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + * indicates a timestamp format. Custom date formats follow the formats at + * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • *
    • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
    • *
    • `maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed * for any given value being read.
    • *
    • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing.
    • - *
        - *
      • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When - * a schema is set by user, it sets `null` for extra fields.
      • - *
      • `DROPMALFORMED` : ignores the whole corrupted records.
      • - *
      • `FAILFAST` : throws an exception when it meets corrupted records.
      • + * during parsing. + *
          + *
        • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When + * a schema is set by user, it sets `null` for extra fields.
        • + *
        • `DROPMALFORMED` : ignores the whole corrupted records.
        • + *
        • `FAILFAST` : throws an exception when it meets corrupted records.
        • + *
        + * *
      * * @since 2.0.0 @@ -255,11 +270,13 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads a Parquet file stream, returning the result as a [[DataFrame]]. * * You can set the following Parquet-specific option(s) for reading Parquet files: + *
        *
      • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
      • *
      • `mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets * whether we should merge schemas collected from all Parquet part-files. This will override * `spark.sql.parquet.mergeSchema`.
      • + *
      * * @since 2.0.0 */ @@ -283,8 +300,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * }}} * * You can set the following text-specific options to deal with text files: + *
        *
      • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
      • + *
      * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index d38e3e58125d..f70c7d08a691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -122,7 +122,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * :: Experimental :: - * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * Specifies the underlying output data source. Built-in options include "parquet" for now. * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 90f95ca9d422..bd3e5a5618ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} /** * :: Experimental :: - * Exception that stopped a [[StreamingQuery]]. + * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception + * that caused the failure. * @param query Query that caused the exception * @param message Message of this exception * @param cause Internal cause of this exception diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 3b3cead3a66d..db606abb8ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -108,6 +108,5 @@ object StreamingQueryListener { @Experimental class QueryTerminated private[sql]( val queryInfo: StreamingQueryInfo, - val exception: Option[String], - val stackTrace: Seq[StackTraceElement]) extends Event + val exception: Option[String]) extends Event } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 318b53cdbbaa..c44fc3d39386 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -327,23 +327,23 @@ private String getResource(String resource) { @Test public void testGenericLoad() { - Dataset df1 = spark.read().format("text").load(getResource("text-suite.txt")); + Dataset df1 = spark.read().format("text").load(getResource("test-data/text-suite.txt")); Assert.assertEquals(4L, df1.count()); Dataset df2 = spark.read().format("text").load( - getResource("text-suite.txt"), - getResource("text-suite2.txt")); + getResource("test-data/text-suite.txt"), + getResource("test-data/text-suite2.txt")); Assert.assertEquals(5L, df2.count()); } @Test public void testTextLoad() { - Dataset ds1 = spark.read().textFile(getResource("text-suite.txt")); + Dataset ds1 = spark.read().textFile(getResource("test-data/text-suite.txt")); Assert.assertEquals(4L, ds1.count()); Dataset ds2 = spark.read().textFile( - getResource("text-suite.txt"), - getResource("text-suite2.txt")); + getResource("test-data/text-suite.txt"), + getResource("test-data/text-suite2.txt")); Assert.assertEquals(5L, ds2.count()); } diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet deleted file mode 100644 index 213f1a90291b..000000000000 Binary files a/sql/core/src/test/resources/old-repeated.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql new file mode 100644 index 000000000000..f62b10ca0037 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql @@ -0,0 +1,34 @@ + +-- unary minus and plus +select -100; +select +230; +select -5.2; +select +6.8e0; +select -key, +key from testdata where key = 2; +select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1; +select -max(key), +max(key) from testdata; +select - (-10); +select + (-key) from testdata where key = 32; +select - (+max(key)) from testdata; +select - - 3; +select - + 20; +select + + 100; +select - - max(key) from testdata; +select + - key from testdata where key = 33; + +-- div +select 5 / 2; +select 5 / 0; +select 5 / null; +select null / 5; +select 5 div 2; +select 5 div 0; +select 5 div null; +select null div 5; + +-- other arithmetics +select 1 + 2; +select 1 - 2; +select 2 * 5; +select 5 % 3; +select pmod(-7, 3); diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql new file mode 100644 index 000000000000..4038a0da41d2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -0,0 +1,86 @@ +-- test cases for array functions + +create temporary view data as select * from values + ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))), + ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223))) + as data(a, b, c); + +select * from data; + +-- index into array +select a, b[0], b[0] + b[1] from data; + +-- index into array of arrays +select a, c[0][0] + c[0][0 + 1] from data; + + +create temporary view primitive_arrays as select * from values ( + array(true), + array(2Y, 1Y), + array(2S, 1S), + array(2, 1), + array(2L, 1L), + array(9223372036854775809, 9223372036854775808), + array(2.0D, 1.0D), + array(float(2.0), float(1.0)), + array(date '2016-03-14', date '2016-03-13'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000') +) as primitive_arrays( + boolean_array, + tinyint_array, + smallint_array, + int_array, + bigint_array, + decimal_array, + double_array, + float_array, + date_array, + timestamp_array +); + +select * from primitive_arrays; + +-- array_contains on all primitive types: result should alternate between true and false +select + array_contains(boolean_array, true), array_contains(boolean_array, false), + array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y), + array_contains(smallint_array, 2S), array_contains(smallint_array, 0S), + array_contains(int_array, 2), array_contains(int_array, 0), + array_contains(bigint_array, 2L), array_contains(bigint_array, 0L), + array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1), + array_contains(double_array, 2.0D), array_contains(double_array, 0.0D), + array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)), + array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'), + array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000') +from primitive_arrays; + +-- array_contains on nested arrays +select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data; + +-- sort_array +select + sort_array(boolean_array), + sort_array(tinyint_array), + sort_array(smallint_array), + sort_array(int_array), + sort_array(bigint_array), + sort_array(decimal_array), + sort_array(double_array), + sort_array(float_array), + sort_array(date_array), + sort_array(timestamp_array) +from primitive_arrays; + +-- size +select + size(boolean_array), + size(tinyint_array), + size(smallint_array), + size(int_array), + size(bigint_array), + size(decimal_array), + size(double_array), + size(float_array), + size(date_array), + size(timestamp_array) +from primitive_arrays; diff --git a/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql b/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql new file mode 100644 index 000000000000..d69f8147a526 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql @@ -0,0 +1,4 @@ +-- This is a query file that has been blacklisted. +-- It includes a query that should crash Spark. +-- If the test case is run, the whole suite would fail. +some random not working query that should crash Spark. diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql new file mode 100644 index 000000000000..3fd1c37e7179 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -0,0 +1,4 @@ +-- date time functions + +-- [SPARK-16836] current_date and current_timestamp literals +select current_date = current_date(), current_timestamp = current_timestamp(); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql new file mode 100644 index 000000000000..36b469c61788 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -0,0 +1,50 @@ +-- group by ordinal positions + +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b); + +-- basic case +select a, sum(b) from data group by 1; + +-- constant case +select 1, 2, sum(b) from data group by 1, 2; + +-- duplicate group by column +select a, 1, sum(b) from data group by a, 1; +select a, 1, sum(b) from data group by 1, 2; + +-- group by a non-aggregate expression's ordinal +select a, b + 2, count(2) from data group by a, 2; + +-- with alias +select a as aa, b + 2 as bb, count(2) from data group by 1, 2; + +-- foldable non-literal: this should be the same as no grouping. +select sum(b) from data group by 1 + 0; + +-- negative cases: ordinal out of range +select a, b from data group by -1; +select a, b from data group by 0; +select a, b from data group by 3; + +-- negative case: position is an aggregate expression +select a, b, sum(b) from data group by 3; +select a, b, sum(b) + 2 from data group by 3; + +-- negative case: nondeterministic expression +select a, rand(0), sum(b) from data group by a, 2; + +-- negative case: star +select * from data group by a, b, 1; + +-- turn of group by ordinal +set spark.sql.groupByOrdinal=false; + +-- can now group by negative literal +select sum(b) from data group by -1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql new file mode 100644 index 000000000000..6741703d9d82 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -0,0 +1,17 @@ +-- Temporary data. +create temporary view myview as values 128, 256 as v(int_col); + +-- group by should produce all input rows, +select int_col, count(*) from myview group by int_col; + +-- group by should produce a single row. +select 'foo', count(*) from myview group by 1; + +-- group-by should not produce any rows (whole stage code generation). +select 'foo' from myview where int_col == 0 group by 1; + +-- group-by should not produce any rows (hash aggregate). +select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; + +-- group-by should not produce any rows (sort aggregate). +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql new file mode 100644 index 000000000000..364c022d959d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -0,0 +1,15 @@ +create temporary view hav as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", 5) + as hav(k, v); + +-- having clause +SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2; + +-- having condition contains grouping column +SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; + +-- SPARK-11032: resolve having correctly +SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql new file mode 100644 index 000000000000..5107fa4d5553 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -0,0 +1,48 @@ + +-- single row, without table and column alias +select * from values ("one", 1); + +-- single row, without column alias +select * from values ("one", 1) as data; + +-- single row +select * from values ("one", 1) as data(a, b); + +-- single column multiple rows +select * from values 1, 2, 3 as data(a); + +-- three rows +select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b); + +-- null type +select * from values ("one", null), ("two", null) as data(a, b); + +-- int and long coercion +select * from values ("one", 1), ("two", 2L) as data(a, b); + +-- foldable expressions +select * from values ("one", 1 + 0), ("two", 1 + 3L) as data(a, b); + +-- complex types +select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b); + +-- decimal and double coercion +select * from values ("one", 2.0), ("two", 3.0D) as data(a, b); + +-- error reporting: nondeterministic function rand +select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b); + +-- error reporting: different number of columns +select * from values ("one", 2.0), ("two") as data(a, b); + +-- error reporting: types that are incompatible +select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b); + +-- error reporting: number aliases different from number data values +select * from values ("one"), ("two") as data(a, b); + +-- error reporting: unresolved expression +select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b); + +-- error reporting: aggregate expression +select * from values ("one", count(1)), ("two", 2) as data(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql new file mode 100644 index 000000000000..2ea35f7f3a5c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -0,0 +1,23 @@ + +-- limit on various data types +select * from testdata limit 2; +select * from arraydata limit 2; +select * from mapdata limit 2; + +-- foldable non-literal in limit +select * from testdata limit 2 + 1; + +select * from testdata limit CAST(1 AS int); + +-- limit must be non-negative +select * from testdata limit -1; + +-- limit must be foldable +select * from testdata limit key > 3; + +-- limit must be integer +select * from testdata limit true; +select * from testdata limit 'a'; + +-- limit within a subquery +select * from (select * from range(10) limit 5) where id > 3; diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql new file mode 100644 index 000000000000..a532a598c6bf --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -0,0 +1,98 @@ +-- Literal parsing + +-- null +select null, Null, nUll; + +-- boolean +select true, tRue, false, fALse; + +-- byte (tinyint) +select 1Y; +select 127Y, -128Y; + +-- out of range byte +select 128Y; + +-- short (smallint) +select 1S; +select 32767S, -32768S; + +-- out of range short +select 32768S; + +-- long (bigint) +select 1L, 2147483648L; +select 9223372036854775807L, -9223372036854775808L; + +-- out of range long +select 9223372036854775808L; + +-- integral parsing + +-- parse int +select 1, -1; + +-- parse int max and min value as int +select 2147483647, -2147483648; + +-- parse long max and min value as long +select 9223372036854775807, -9223372036854775808; + +-- parse as decimals (Long.MaxValue + 1, and Long.MinValue - 1) +select 9223372036854775808, -9223372036854775809; + +-- out of range decimal numbers +select 1234567890123456789012345678901234567890; +select 1234567890123456789012345678901234567890.0; + +-- double +select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1; +select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5; +-- negative double +select .e3; +-- inf and -inf +select 1E309, -1E309; + +-- decimal parsing +select 0.3, -0.8, .5, -.18, 0.1111, .1111; + +-- super large scientific notation numbers should still be valid doubles +select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10; + +-- string +select "Hello Peter!", 'hello lee!'; +-- multi string +select 'hello' 'world', 'hello' " " 'lee'; +-- single quote within double quotes +select "hello 'peter'"; +select 'pattern%', 'no-pattern\%', 'pattern\\%', 'pattern\\\%'; +select '\'', '"', '\n', '\r', '\t', 'Z'; +-- "Hello!" in octals +select '\110\145\154\154\157\041'; +-- "World :)" in unicode +select '\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'; + +-- date +select dAte '2016-03-12'; +-- invalid date +select date 'mar 11 2016'; + +-- timestamp +select tImEstAmp '2016-03-11 20:54:00.000'; +-- invalid timestamp +select timestamp '2016-33-11 20:54:00.000'; + +-- interval +select interval 13.123456789 seconds, interval -13.123456789 second; +select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond; +-- ns is not supported +select interval 10 nanoseconds; + +-- unsupported data type +select GEO '(10,-6)'; + +-- big decimal parsing +select 90912830918230182310293801923652346786BD, 123.0E-28BD, 123.08BD; + +-- out of range big decimal +select 1.20E-38BD; diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql new file mode 100644 index 000000000000..71a50157b766 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -0,0 +1,20 @@ +create temporary view nt1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3) + as nt1(k, v1); + +create temporary view nt2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5) + as nt2(k, v2); + + +SELECT * FROM nt1 natural join nt2 where k = "one"; + +SELECT * FROM nt1 natural left join nt2 order by v1, v2; + +SELECT * FROM nt1 natural right join nt2 order by v1, v2; + +SELECT count(*) FROM nt1 natural full outer join nt2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql new file mode 100644 index 000000000000..66549da7971d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql @@ -0,0 +1,9 @@ + +-- count(null) should be 0 +SELECT COUNT(NULL) FROM VALUES 1, 2, 3; +SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3; + +-- count(null) on window should be 0 +SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3; +SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3; + diff --git a/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql new file mode 100644 index 000000000000..8d733e77fa8d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql @@ -0,0 +1,36 @@ +-- order by and sort by ordinal positions + +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b); + +select * from data order by 1 desc; + +-- mix ordinal and column name +select * from data order by 1 desc, b desc; + +-- order by multiple ordinals +select * from data order by 1 desc, 2 desc; + +-- 1 + 0 is considered a constant (not an ordinal) and thus ignored +select * from data order by 1 + 0 desc, b desc; + +-- negative cases: ordinal position out of range +select * from data order by 0; +select * from data order by -1; +select * from data order by 3; + +-- sort by ordinal +select * from data sort by 1 desc; + +-- turn off order by ordinal +set spark.sql.orderByOrdinal=false; + +-- 0 is now a valid literal +select * from data order by 0; +select * from data sort by 0; diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql new file mode 100644 index 000000000000..f50f1ebad970 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -0,0 +1,36 @@ +-- SPARK-17099: Incorrect result when HAVING clause is added to group by query +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(-234), (145), (367), (975), (298) +as t1(int_col1); + +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES +(-769, -244), (-800, -409), (940, 86), (-507, 304), (-367, 158) +as t2(int_col0, int_col1); + +SELECT + (SUM(COALESCE(t1.int_col1, t2.int_col0))), + ((COALESCE(t1.int_col1, t2.int_col0)) * 2) +FROM t1 +RIGHT JOIN t2 + ON (t2.int_col0) = (t1.int_col1) +GROUP BY GREATEST(COALESCE(t2.int_col1, 109), COALESCE(t1.int_col1, -449)), + COALESCE(t1.int_col1, t2.int_col0) +HAVING (SUM(COALESCE(t1.int_col1, t2.int_col0))) + > ((COALESCE(t1.int_col1, t2.int_col0)) * 2); + + +-- SPARK-17120: Analyzer incorrectly optimizes plan to empty LocalRelation +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1); + +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1); + +SELECT * +FROM ( +SELECT + COALESCE(t2.int_col1, t1.int_col1) AS int_col + FROM t1 + LEFT JOIN t2 ON false +) t where (t.int_col) is not null; + + + diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql new file mode 100644 index 000000000000..2e6dcd538b7a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -0,0 +1,20 @@ +-- unresolved function +select * from dummy(3); + +-- range call with end +select * from range(6 + cos(3)); + +-- range call with start and end +select * from range(5, 10); + +-- range call with step +select * from range(0, 10, 2); + +-- range call with numPartitions +select * from range(0, 10, 1, 200); + +-- range call error +select * from range(1, 1, 1, 1, 1); + +-- range call with null +select * from range(1, null); diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out new file mode 100644 index 000000000000..6abe048af477 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out @@ -0,0 +1,226 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query 0 +select -100 +-- !query 0 schema +struct<-100:int> +-- !query 0 output +-100 + + +-- !query 1 +select +230 +-- !query 1 schema +struct<230:int> +-- !query 1 output +230 + + +-- !query 2 +select -5.2 +-- !query 2 schema +struct<-5.2:decimal(2,1)> +-- !query 2 output +-5.2 + + +-- !query 3 +select +6.8e0 +-- !query 3 schema +struct<6.8:double> +-- !query 3 output +6.8 + + +-- !query 4 +select -key, +key from testdata where key = 2 +-- !query 4 schema +struct<(- key):int,key:int> +-- !query 4 output +-2 2 + + +-- !query 5 +select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1 +-- !query 5 schema +struct<(- (key + 1)):int,((- key) + 1):int,(key + 5):int> +-- !query 5 output +-2 0 6 + + +-- !query 6 +select -max(key), +max(key) from testdata +-- !query 6 schema +struct<(- max(key)):int,max(key):int> +-- !query 6 output +-100 100 + + +-- !query 7 +select - (-10) +-- !query 7 schema +struct<(- -10):int> +-- !query 7 output +10 + + +-- !query 8 +select + (-key) from testdata where key = 32 +-- !query 8 schema +struct<(- key):int> +-- !query 8 output +-32 + + +-- !query 9 +select - (+max(key)) from testdata +-- !query 9 schema +struct<(- max(key)):int> +-- !query 9 output +-100 + + +-- !query 10 +select - - 3 +-- !query 10 schema +struct<(- -3):int> +-- !query 10 output +3 + + +-- !query 11 +select - + 20 +-- !query 11 schema +struct<(- 20):int> +-- !query 11 output +-20 + + +-- !query 12 +select + + 100 +-- !query 12 schema +struct<100:int> +-- !query 12 output +100 + + +-- !query 13 +select - - max(key) from testdata +-- !query 13 schema +struct<(- (- max(key))):int> +-- !query 13 output +100 + + +-- !query 14 +select + - key from testdata where key = 33 +-- !query 14 schema +struct<(- key):int> +-- !query 14 output +-33 + + +-- !query 15 +select 5 / 2 +-- !query 15 schema +struct<(CAST(5 AS DOUBLE) / CAST(2 AS DOUBLE)):double> +-- !query 15 output +2.5 + + +-- !query 16 +select 5 / 0 +-- !query 16 schema +struct<(CAST(5 AS DOUBLE) / CAST(0 AS DOUBLE)):double> +-- !query 16 output +NULL + + +-- !query 17 +select 5 / null +-- !query 17 schema +struct<(CAST(5 AS DOUBLE) / CAST(NULL AS DOUBLE)):double> +-- !query 17 output +NULL + + +-- !query 18 +select null / 5 +-- !query 18 schema +struct<(CAST(NULL AS DOUBLE) / CAST(5 AS DOUBLE)):double> +-- !query 18 output +NULL + + +-- !query 19 +select 5 div 2 +-- !query 19 schema +struct +-- !query 19 output +2 + + +-- !query 20 +select 5 div 0 +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select 5 div null +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select null div 5 +-- !query 22 schema +struct +-- !query 22 output +NULL + + +-- !query 23 +select 1 + 2 +-- !query 23 schema +struct<(1 + 2):int> +-- !query 23 output +3 + + +-- !query 24 +select 1 - 2 +-- !query 24 schema +struct<(1 - 2):int> +-- !query 24 output +-1 + + +-- !query 25 +select 2 * 5 +-- !query 25 schema +struct<(2 * 5):int> +-- !query 25 output +10 + + +-- !query 26 +select 5 % 3 +-- !query 26 schema +struct<(5 % 3):int> +-- !query 26 output +2 + + +-- !query 27 +select pmod(-7, 3) +-- !query 27 schema +struct +-- !query 27 output +2 diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out new file mode 100644 index 000000000000..4a1d149c1f36 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -0,0 +1,144 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +create temporary view data as select * from values + ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))), + ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223))) + as data(a, b, c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select * from data +-- !query 1 schema +struct,c:array>> +-- !query 1 output +one [11,12,13] [[111,112,113],[121,122,123]] +two [21,22,23] [[211,212,213],[221,222,223]] + + +-- !query 2 +select a, b[0], b[0] + b[1] from data +-- !query 2 schema +struct +-- !query 2 output +one 11 23 +two 21 43 + + +-- !query 3 +select a, c[0][0] + c[0][0 + 1] from data +-- !query 3 schema +struct +-- !query 3 output +one 223 +two 423 + + +-- !query 4 +create temporary view primitive_arrays as select * from values ( + array(true), + array(2Y, 1Y), + array(2S, 1S), + array(2, 1), + array(2L, 1L), + array(9223372036854775809, 9223372036854775808), + array(2.0D, 1.0D), + array(float(2.0), float(1.0)), + array(date '2016-03-14', date '2016-03-13'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000') +) as primitive_arrays( + boolean_array, + tinyint_array, + smallint_array, + int_array, + bigint_array, + decimal_array, + double_array, + float_array, + date_array, + timestamp_array +) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +select * from primitive_arrays +-- !query 5 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,date_array:array,timestamp_array:array> +-- !query 5 output +[true] [2,1] [2,1] [2,1] [2,1] [9223372036854775809,9223372036854775808] [2.0,1.0] [2.0,1.0] [2016-03-14,2016-03-13] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0] + + +-- !query 6 +select + array_contains(boolean_array, true), array_contains(boolean_array, false), + array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y), + array_contains(smallint_array, 2S), array_contains(smallint_array, 0S), + array_contains(int_array, 2), array_contains(int_array, 0), + array_contains(bigint_array, 2L), array_contains(bigint_array, 0L), + array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1), + array_contains(double_array, 2.0D), array_contains(double_array, 0.0D), + array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)), + array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'), + array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000') +from primitive_arrays +-- !query 6 schema +struct +-- !query 6 output +true false true false true false true false true false true false true false true false true false true false + + +-- !query 7 +select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data +-- !query 7 schema +struct +-- !query 7 output +false false +true true + + +-- !query 8 +select + sort_array(boolean_array), + sort_array(tinyint_array), + sort_array(smallint_array), + sort_array(int_array), + sort_array(bigint_array), + sort_array(decimal_array), + sort_array(double_array), + sort_array(float_array), + sort_array(date_array), + sort_array(timestamp_array) +from primitive_arrays +-- !query 8 schema +struct,sort_array(tinyint_array, true):array,sort_array(smallint_array, true):array,sort_array(int_array, true):array,sort_array(bigint_array, true):array,sort_array(decimal_array, true):array,sort_array(double_array, true):array,sort_array(float_array, true):array,sort_array(date_array, true):array,sort_array(timestamp_array, true):array> +-- !query 8 output +[true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00.0,2016-11-15 20:54:00.0] + + +-- !query 9 +select + size(boolean_array), + size(tinyint_array), + size(smallint_array), + size(int_array), + size(bigint_array), + size(decimal_array), + size(double_array), + size(float_array), + size(date_array), + size(timestamp_array) +from primitive_arrays +-- !query 9 schema +struct +-- !query 9 output +1 2 2 2 2 2 2 2 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out new file mode 100644 index 000000000000..032e4258500f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -0,0 +1,10 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 1 + + +-- !query 0 +select current_date = current_date(), current_timestamp = current_timestamp() +-- !query 0 schema +struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean> +-- !query 0 output +true true diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out new file mode 100644 index 000000000000..2f10b7ebc6d3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -0,0 +1,168 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select a, sum(b) from data group by 1 +-- !query 1 schema +struct +-- !query 1 output +1 3 +2 3 +3 3 + + +-- !query 2 +select 1, 2, sum(b) from data group by 1, 2 +-- !query 2 schema +struct<1:int,2:int,sum(b):bigint> +-- !query 2 output +1 2 9 + + +-- !query 3 +select a, 1, sum(b) from data group by a, 1 +-- !query 3 schema +struct +-- !query 3 output +1 1 3 +2 1 3 +3 1 3 + + +-- !query 4 +select a, 1, sum(b) from data group by 1, 2 +-- !query 4 schema +struct +-- !query 4 output +1 1 3 +2 1 3 +3 1 3 + + +-- !query 5 +select a, b + 2, count(2) from data group by a, 2 +-- !query 5 schema +struct +-- !query 5 output +1 3 1 +1 4 1 +2 3 1 +2 4 1 +3 3 1 +3 4 1 + + +-- !query 6 +select a as aa, b + 2 as bb, count(2) from data group by 1, 2 +-- !query 6 schema +struct +-- !query 6 output +1 3 1 +1 4 1 +2 3 1 +2 4 1 +3 3 1 +3 4 1 + + +-- !query 7 +select sum(b) from data group by 1 + 0 +-- !query 7 schema +struct +-- !query 7 output +9 + + +-- !query 8 +select a, b from data group by -1 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +GROUP BY position -1 is not in select list (valid range is [1, 2]); line 1 pos 31 + + +-- !query 9 +select a, b from data group by 0 +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +GROUP BY position 0 is not in select list (valid range is [1, 2]); line 1 pos 31 + + +-- !query 10 +select a, b from data group by 3 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +GROUP BY position 3 is not in select list (valid range is [1, 2]); line 1 pos 31 + + +-- !query 11 +select a, b, sum(b) from data group by 3 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 + + +-- !query 12 +select a, b, sum(b) + 2 from data group by 3 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 + + +-- !query 13 +select a, rand(0), sum(b) from data group by a, 2 +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +nondeterministic expression rand(0) should not appear in grouping expression.; + + +-- !query 14 +select * from data group by a, b, 1 +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Star (*) is not allowed in select list when GROUP BY ordinal position is used; + + +-- !query 15 +set spark.sql.groupByOrdinal=false +-- !query 15 schema +struct +-- !query 15 output +spark.sql.groupByOrdinal + + +-- !query 16 +select sum(b) from data group by -1 +-- !query 16 schema +struct +-- !query 16 output +9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out new file mode 100644 index 000000000000..9127bd4dd4c6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -0,0 +1,51 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +create temporary view myview as values 128, 256 as v(int_col) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select int_col, count(*) from myview group by int_col +-- !query 1 schema +struct +-- !query 1 output +128 1 +256 1 + + +-- !query 2 +select 'foo', count(*) from myview group by 1 +-- !query 2 schema +struct +-- !query 2 output +foo 2 + + +-- !query 3 +select 'foo' from myview where int_col == 0 group by 1 +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1 +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 +-- !query 5 schema +struct> +-- !query 5 output + diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out new file mode 100644 index 000000000000..e0923832673c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -0,0 +1,40 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +create temporary view hav as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", 5) + as hav(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2 +-- !query 1 schema +struct +-- !query 1 output +one 6 +three 3 + + +-- !query 2 +SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2 +-- !query 2 schema +struct +-- !query 2 output +1 + + +-- !query 3 +SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0) +-- !query 3 schema +struct +-- !query 3 output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out new file mode 100644 index 000000000000..de6f01b8de77 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -0,0 +1,145 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 16 + + +-- !query 0 +select * from values ("one", 1) +-- !query 0 schema +struct +-- !query 0 output +one 1 + + +-- !query 1 +select * from values ("one", 1) as data +-- !query 1 schema +struct +-- !query 1 output +one 1 + + +-- !query 2 +select * from values ("one", 1) as data(a, b) +-- !query 2 schema +struct +-- !query 2 output +one 1 + + +-- !query 3 +select * from values 1, 2, 3 as data(a) +-- !query 3 schema +struct +-- !query 3 output +1 +2 +3 + + +-- !query 4 +select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) +-- !query 4 schema +struct +-- !query 4 output +one 1 +three NULL +two 2 + + +-- !query 5 +select * from values ("one", null), ("two", null) as data(a, b) +-- !query 5 schema +struct +-- !query 5 output +one NULL +two NULL + + +-- !query 6 +select * from values ("one", 1), ("two", 2L) as data(a, b) +-- !query 6 schema +struct +-- !query 6 output +one 1 +two 2 + + +-- !query 7 +select * from values ("one", 1 + 0), ("two", 1 + 3L) as data(a, b) +-- !query 7 schema +struct +-- !query 7 output +one 1 +two 4 + + +-- !query 8 +select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b) +-- !query 8 schema +struct> +-- !query 8 output +one [0,1] +two [2,3] + + +-- !query 9 +select * from values ("one", 2.0), ("two", 3.0D) as data(a, b) +-- !query 9 schema +struct +-- !query 9 output +one 2.0 +two 3.0 + + +-- !query 10 +select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b) +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot evaluate expression rand(5) in inline table definition; line 1 pos 29 + + +-- !query 11 +select * from values ("one", 2.0), ("two") as data(a, b) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +expected 2 columns but found 1 columns in row 1; line 1 pos 14 + + +-- !query 12 +select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +incompatible types found in column b for inline table; line 1 pos 14 + + +-- !query 13 +select * from values ("one"), ("two") as data(a, b) +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +expected 2 columns but found 1 columns in row 0; line 1 pos 14 + + +-- !query 14 +select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Undefined function: 'random_not_exist_func'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 29 + + +-- !query 15 +select * from values ("one", count(1)), ("two", 2) as data(a, b) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot evaluate expression count(1) in inline table definition; line 1 pos 29 diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out new file mode 100644 index 000000000000..cb4e4d04810d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -0,0 +1,91 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +select * from testdata limit 2 +-- !query 0 schema +struct +-- !query 0 output +1 1 +2 2 + + +-- !query 1 +select * from arraydata limit 2 +-- !query 1 schema +struct,nestedarraycol:array>> +-- !query 1 output +[1,2,3] [[1,2,3]] +[2,3,4] [[2,3,4]] + + +-- !query 2 +select * from mapdata limit 2 +-- !query 2 schema +struct> +-- !query 2 output +{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} +{1:"a2",2:"b2",3:"c2",4:"d2"} + + +-- !query 3 +select * from testdata limit 2 + 1 +-- !query 3 schema +struct +-- !query 3 output +1 1 +2 2 +3 3 + + +-- !query 4 +select * from testdata limit CAST(1 AS int) +-- !query 4 schema +struct +-- !query 4 output +1 1 + + +-- !query 5 +select * from testdata limit -1 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +The limit expression must be equal to or greater than 0, but got -1; + + +-- !query 6 +select * from testdata limit key > 3 +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); + + +-- !query 7 +select * from testdata limit true +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got boolean; + + +-- !query 8 +select * from testdata limit 'a' +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; + + +-- !query 9 +select * from (select * from range(10) limit 5) where id > 3 +-- !query 9 schema +struct +-- !query 9 output +4 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out new file mode 100644 index 000000000000..85629f7ba813 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -0,0 +1,378 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 40 + + +-- !query 0 +select null, Null, nUll +-- !query 0 schema +struct +-- !query 0 output +NULL NULL NULL + + +-- !query 1 +select true, tRue, false, fALse +-- !query 1 schema +struct +-- !query 1 output +true true false false + + +-- !query 2 +select 1Y +-- !query 2 schema +struct<1:tinyint> +-- !query 2 output +1 + + +-- !query 3 +select 127Y, -128Y +-- !query 3 schema +struct<127:tinyint,-128:tinyint> +-- !query 3 output +127 -128 + + +-- !query 4 +select 128Y +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal 128 does not fit in range [-128, 127] for type tinyint(line 1, pos 7) + +== SQL == +select 128Y +-------^^^ + + +-- !query 5 +select 1S +-- !query 5 schema +struct<1:smallint> +-- !query 5 output +1 + + +-- !query 6 +select 32767S, -32768S +-- !query 6 schema +struct<32767:smallint,-32768:smallint> +-- !query 6 output +32767 -32768 + + +-- !query 7 +select 32768S +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal 32768 does not fit in range [-32768, 32767] for type smallint(line 1, pos 7) + +== SQL == +select 32768S +-------^^^ + + +-- !query 8 +select 1L, 2147483648L +-- !query 8 schema +struct<1:bigint,2147483648:bigint> +-- !query 8 output +1 2147483648 + + +-- !query 9 +select 9223372036854775807L, -9223372036854775808L +-- !query 9 schema +struct<9223372036854775807:bigint,-9223372036854775808:bigint> +-- !query 9 output +9223372036854775807 -9223372036854775808 + + +-- !query 10 +select 9223372036854775808L +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal 9223372036854775808 does not fit in range [-9223372036854775808, 9223372036854775807] for type bigint(line 1, pos 7) + +== SQL == +select 9223372036854775808L +-------^^^ + + +-- !query 11 +select 1, -1 +-- !query 11 schema +struct<1:int,-1:int> +-- !query 11 output +1 -1 + + +-- !query 12 +select 2147483647, -2147483648 +-- !query 12 schema +struct<2147483647:int,-2147483648:int> +-- !query 12 output +2147483647 -2147483648 + + +-- !query 13 +select 9223372036854775807, -9223372036854775808 +-- !query 13 schema +struct<9223372036854775807:bigint,-9223372036854775808:bigint> +-- !query 13 output +9223372036854775807 -9223372036854775808 + + +-- !query 14 +select 9223372036854775808, -9223372036854775809 +-- !query 14 schema +struct<9223372036854775808:decimal(19,0),-9223372036854775809:decimal(19,0)> +-- !query 14 output +9223372036854775808 -9223372036854775809 + + +-- !query 15 +select 1234567890123456789012345678901234567890 +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.catalyst.parser.ParseException + +DecimalType can only support precision up to 38 +== SQL == +select 1234567890123456789012345678901234567890 + + +-- !query 16 +select 1234567890123456789012345678901234567890.0 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.catalyst.parser.ParseException + +DecimalType can only support precision up to 38 +== SQL == +select 1234567890123456789012345678901234567890.0 + + +-- !query 17 +select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1 +-- !query 17 schema +struct<1.0:double,1.2:double,1.0E10:double,150000.0:double,0.1:double,0.1:double,10000.0:double,90.0:double,90.0:double,90.0:double,90.0:double> +-- !query 17 output +1.0 1.2 1.0E10 150000.0 0.1 0.1 10000.0 90.0 90.0 90.0 90.0 + + +-- !query 18 +select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5 +-- !query 18 schema +struct<-1.0:double,-1.2:double,-1.0E10:double,-150000.0:double,-0.1:double,-0.1:double,-10000.0:double> +-- !query 18 output +-1.0 -1.2 -1.0E10 -150000.0 -0.1 -0.1 -10000.0 + + +-- !query 19 +select .e3 +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'select .'(line 1, pos 7) + +== SQL == +select .e3 +-------^^^ + + +-- !query 20 +select 1E309, -1E309 +-- !query 20 schema +struct +-- !query 20 output +Infinity -Infinity + + +-- !query 21 +select 0.3, -0.8, .5, -.18, 0.1111, .1111 +-- !query 21 schema +struct<0.3:decimal(1,1),-0.8:decimal(1,1),0.5:decimal(1,1),-0.18:decimal(2,2),0.1111:decimal(4,4),0.1111:decimal(4,4)> +-- !query 21 output +0.3 -0.8 0.5 -0.18 0.1111 0.1111 + + +-- !query 22 +select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10 +-- !query 22 schema +struct<1.2345678901234568E48:double,1.2345678901234568E48:double> +-- !query 22 output +1.2345678901234568E48 1.2345678901234568E48 + + +-- !query 23 +select "Hello Peter!", 'hello lee!' +-- !query 23 schema +struct +-- !query 23 output +Hello Peter! hello lee! + + +-- !query 24 +select 'hello' 'world', 'hello' " " 'lee' +-- !query 24 schema +struct +-- !query 24 output +helloworld hello lee + + +-- !query 25 +select "hello 'peter'" +-- !query 25 schema +struct +-- !query 25 output +hello 'peter' + + +-- !query 26 +select 'pattern%', 'no-pattern\%', 'pattern\\%', 'pattern\\\%' +-- !query 26 schema +struct +-- !query 26 output +pattern% no-pattern\% pattern\% pattern\\% + + +-- !query 27 +select '\'', '"', '\n', '\r', '\t', 'Z' +-- !query 27 schema +struct<':string,":string, +:string, :string, :string,Z:string> +-- !query 27 output +' " + Z + + +-- !query 28 +select '\110\145\154\154\157\041' +-- !query 28 schema +struct +-- !query 28 output +Hello! + + +-- !query 29 +select '\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029' +-- !query 29 schema +struct +-- !query 29 output +World :) + + +-- !query 30 +select dAte '2016-03-12' +-- !query 30 schema +struct +-- !query 30 output +2016-03-12 + + +-- !query 31 +select date 'mar 11 2016' +-- !query 31 schema +struct<> +-- !query 31 output +java.lang.IllegalArgumentException +null + + +-- !query 32 +select tImEstAmp '2016-03-11 20:54:00.000' +-- !query 32 schema +struct +-- !query 32 output +2016-03-11 20:54:00 + + +-- !query 33 +select timestamp '2016-33-11 20:54:00.000' +-- !query 33 schema +struct<> +-- !query 33 output +java.lang.IllegalArgumentException +Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff] + + +-- !query 34 +select interval 13.123456789 seconds, interval -13.123456789 second +-- !query 34 schema +struct<> +-- !query 34 output +scala.MatchError +(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2) + + +-- !query 35 +select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond +-- !query 35 schema +struct<> +-- !query 35 output +scala.MatchError +(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2) + + +-- !query 36 +select interval 10 nanoseconds +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.catalyst.parser.ParseException + +No interval can be constructed(line 1, pos 16) + +== SQL == +select interval 10 nanoseconds +----------------^^^ + + +-- !query 37 +select GEO '(10,-6)' +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.catalyst.parser.ParseException + +Literals of type 'GEO' are currently not supported.(line 1, pos 7) + +== SQL == +select GEO '(10,-6)' +-------^^^ + + +-- !query 38 +select 90912830918230182310293801923652346786BD, 123.0E-28BD, 123.08BD +-- !query 38 schema +struct<90912830918230182310293801923652346786:decimal(38,0),1.230E-26:decimal(29,29),123.08:decimal(5,2)> +-- !query 38 output +90912830918230182310293801923652346786 0.0000000000000000000000000123 123.08 + + +-- !query 39 +select 1.20E-38BD +-- !query 39 schema +struct<> +-- !query 39 output +org.apache.spark.sql.catalyst.parser.ParseException + +DecimalType can only support precision up to 38(line 1, pos 7) + +== SQL == +select 1.20E-38BD +-------^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out b/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out new file mode 100644 index 000000000000..43f2f9af61d9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out @@ -0,0 +1,64 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +create temporary view nt1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3) + as nt1(k, v1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view nt2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5) + as nt2(k, v2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM nt1 natural join nt2 where k = "one" +-- !query 2 schema +struct +-- !query 2 output +one 1 1 +one 1 5 + + +-- !query 3 +SELECT * FROM nt1 natural left join nt2 order by v1, v2 +-- !query 3 schema +struct +-- !query 3 output +one 1 1 +one 1 5 +two 2 22 +three 3 NULL + + +-- !query 4 +SELECT * FROM nt1 natural right join nt2 order by v1, v2 +-- !query 4 schema +struct +-- !query 4 output +one 1 1 +one 1 5 +two 2 22 + + +-- !query 5 +SELECT count(*) FROM nt1 natural full outer join nt2 +-- !query 5 schema +struct +-- !query 5 output +4 diff --git a/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out new file mode 100644 index 000000000000..ed3a651aa661 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out @@ -0,0 +1,38 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +SELECT COUNT(NULL) FROM VALUES 1, 2, 3 +-- !query 0 schema +struct +-- !query 0 output +0 + + +-- !query 1 +SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3 +-- !query 1 schema +struct +-- !query 1 output +0 + + +-- !query 2 +SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3 +-- !query 2 schema +struct +-- !query 2 output +0 +0 +0 + + +-- !query 3 +SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3 +-- !query 3 schema +struct +-- !query 3 output +0 +0 +0 diff --git a/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out new file mode 100644 index 000000000000..03a4e72d0fa3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out @@ -0,0 +1,143 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select * from data order by 1 desc +-- !query 1 schema +struct +-- !query 1 output +3 1 +3 2 +2 1 +2 2 +1 1 +1 2 + + +-- !query 2 +select * from data order by 1 desc, b desc +-- !query 2 schema +struct +-- !query 2 output +3 2 +3 1 +2 2 +2 1 +1 2 +1 1 + + +-- !query 3 +select * from data order by 1 desc, 2 desc +-- !query 3 schema +struct +-- !query 3 output +3 2 +3 1 +2 2 +2 1 +1 2 +1 1 + + +-- !query 4 +select * from data order by 1 + 0 desc, b desc +-- !query 4 schema +struct +-- !query 4 output +1 2 +2 2 +3 2 +1 1 +2 1 +3 1 + + +-- !query 5 +select * from data order by 0 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +ORDER BY position 0 is not in select list (valid range is [1, 2]); line 1 pos 28 + + +-- !query 6 +select * from data order by -1 +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +ORDER BY position -1 is not in select list (valid range is [1, 2]); line 1 pos 28 + + +-- !query 7 +select * from data order by 3 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +ORDER BY position 3 is not in select list (valid range is [1, 2]); line 1 pos 28 + + +-- !query 8 +select * from data sort by 1 desc +-- !query 8 schema +struct +-- !query 8 output +1 1 +1 2 +2 1 +2 2 +3 1 +3 2 + + +-- !query 9 +set spark.sql.orderByOrdinal=false +-- !query 9 schema +struct +-- !query 9 output +spark.sql.orderByOrdinal + + +-- !query 10 +select * from data order by 0 +-- !query 10 schema +struct +-- !query 10 output +1 1 +1 2 +2 1 +2 2 +3 1 +3 2 + + +-- !query 11 +select * from data sort by 0 +-- !query 11 schema +struct +-- !query 11 output +1 1 +1 2 +2 1 +2 2 +3 1 +3 2 diff --git a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out new file mode 100644 index 000000000000..b39fdb0e5872 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out @@ -0,0 +1,72 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(-234), (145), (367), (975), (298) +as t1(int_col1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES +(-769, -244), (-800, -409), (940, 86), (-507, 304), (-367, 158) +as t2(int_col0, int_col1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT + (SUM(COALESCE(t1.int_col1, t2.int_col0))), + ((COALESCE(t1.int_col1, t2.int_col0)) * 2) +FROM t1 +RIGHT JOIN t2 + ON (t2.int_col0) = (t1.int_col1) +GROUP BY GREATEST(COALESCE(t2.int_col1, 109), COALESCE(t1.int_col1, -449)), + COALESCE(t1.int_col1, t2.int_col0) +HAVING (SUM(COALESCE(t1.int_col1, t2.int_col0))) + > ((COALESCE(t1.int_col1, t2.int_col0)) * 2) +-- !query 2 schema +struct +-- !query 2 output +-367 -734 +-507 -1014 +-769 -1538 +-800 -1600 + + +-- !query 3 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +SELECT * +FROM ( +SELECT + COALESCE(t2.int_col1, t1.int_col1) AS int_col + FROM t1 + LEFT JOIN t2 ON false +) t where (t.int_col) is not null +-- !query 5 schema +struct +-- !query 5 output +97 diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out new file mode 100644 index 000000000000..d769bcef0aca --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -0,0 +1,87 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +select * from dummy(3) +-- !query 0 schema +struct<> +-- !query 0 output +org.apache.spark.sql.AnalysisException +could not resolve `dummy` to a table-valued function; line 1 pos 14 + + +-- !query 1 +select * from range(6 + cos(3)) +-- !query 1 schema +struct +-- !query 1 output +0 +1 +2 +3 +4 + + +-- !query 2 +select * from range(5, 10) +-- !query 2 schema +struct +-- !query 2 output +5 +6 +7 +8 +9 + + +-- !query 3 +select * from range(0, 10, 2) +-- !query 3 schema +struct +-- !query 3 output +0 +2 +4 +6 +8 + + +-- !query 4 +select * from range(0, 10, 1, 200) +-- !query 4 schema +struct +-- !query 4 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 5 +select * from range(1, 1, 1, 1, 1) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +error: table-valued function range with alternatives: + (end: long) + (start: long, end: long) + (start: long, end: long, step: long) + (start: long, end: long, step: long, numPartitions: integer) +cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14 + + +-- !query 6 +select * from range(1, null) +-- !query 6 schema +struct<> +-- !query 6 output +java.lang.IllegalArgumentException +Invalid arguments for resolved function: 1, null diff --git a/sql/core/src/test/resources/bool.csv b/sql/core/src/test/resources/test-data/bool.csv similarity index 100% rename from sql/core/src/test/resources/bool.csv rename to sql/core/src/test/resources/test-data/bool.csv diff --git a/sql/core/src/test/resources/cars-alternative.csv b/sql/core/src/test/resources/test-data/cars-alternative.csv similarity index 100% rename from sql/core/src/test/resources/cars-alternative.csv rename to sql/core/src/test/resources/test-data/cars-alternative.csv diff --git a/sql/core/src/test/resources/cars-blank-column-name.csv b/sql/core/src/test/resources/test-data/cars-blank-column-name.csv similarity index 100% rename from sql/core/src/test/resources/cars-blank-column-name.csv rename to sql/core/src/test/resources/test-data/cars-blank-column-name.csv diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/test-data/cars-malformed.csv similarity index 100% rename from sql/core/src/test/resources/cars-malformed.csv rename to sql/core/src/test/resources/test-data/cars-malformed.csv diff --git a/sql/core/src/test/resources/cars-null.csv b/sql/core/src/test/resources/test-data/cars-null.csv similarity index 100% rename from sql/core/src/test/resources/cars-null.csv rename to sql/core/src/test/resources/test-data/cars-null.csv diff --git a/sql/core/src/test/resources/cars-unbalanced-quotes.csv b/sql/core/src/test/resources/test-data/cars-unbalanced-quotes.csv similarity index 100% rename from sql/core/src/test/resources/cars-unbalanced-quotes.csv rename to sql/core/src/test/resources/test-data/cars-unbalanced-quotes.csv diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/test-data/cars.csv similarity index 100% rename from sql/core/src/test/resources/cars.csv rename to sql/core/src/test/resources/test-data/cars.csv diff --git a/sql/core/src/test/resources/cars.tsv b/sql/core/src/test/resources/test-data/cars.tsv similarity index 100% rename from sql/core/src/test/resources/cars.tsv rename to sql/core/src/test/resources/test-data/cars.tsv diff --git a/sql/core/src/test/resources/cars_iso-8859-1.csv b/sql/core/src/test/resources/test-data/cars_iso-8859-1.csv similarity index 100% rename from sql/core/src/test/resources/cars_iso-8859-1.csv rename to sql/core/src/test/resources/test-data/cars_iso-8859-1.csv diff --git a/sql/core/src/test/resources/comments.csv b/sql/core/src/test/resources/test-data/comments.csv similarity index 100% rename from sql/core/src/test/resources/comments.csv rename to sql/core/src/test/resources/test-data/comments.csv diff --git a/sql/core/src/test/resources/dates.csv b/sql/core/src/test/resources/test-data/dates.csv similarity index 100% rename from sql/core/src/test/resources/dates.csv rename to sql/core/src/test/resources/test-data/dates.csv diff --git a/sql/core/src/test/resources/dec-in-fixed-len.parquet b/sql/core/src/test/resources/test-data/dec-in-fixed-len.parquet similarity index 100% rename from sql/core/src/test/resources/dec-in-fixed-len.parquet rename to sql/core/src/test/resources/test-data/dec-in-fixed-len.parquet diff --git a/sql/core/src/test/resources/dec-in-i32.parquet b/sql/core/src/test/resources/test-data/dec-in-i32.parquet similarity index 100% rename from sql/core/src/test/resources/dec-in-i32.parquet rename to sql/core/src/test/resources/test-data/dec-in-i32.parquet diff --git a/sql/core/src/test/resources/dec-in-i64.parquet b/sql/core/src/test/resources/test-data/dec-in-i64.parquet similarity index 100% rename from sql/core/src/test/resources/dec-in-i64.parquet rename to sql/core/src/test/resources/test-data/dec-in-i64.parquet diff --git a/sql/core/src/test/resources/decimal.csv b/sql/core/src/test/resources/test-data/decimal.csv similarity index 100% rename from sql/core/src/test/resources/decimal.csv rename to sql/core/src/test/resources/test-data/decimal.csv diff --git a/sql/core/src/test/resources/disable_comments.csv b/sql/core/src/test/resources/test-data/disable_comments.csv similarity index 100% rename from sql/core/src/test/resources/disable_comments.csv rename to sql/core/src/test/resources/test-data/disable_comments.csv diff --git a/sql/core/src/test/resources/empty.csv b/sql/core/src/test/resources/test-data/empty.csv similarity index 100% rename from sql/core/src/test/resources/empty.csv rename to sql/core/src/test/resources/test-data/empty.csv diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/test-data/nested-array-struct.parquet similarity index 100% rename from sql/core/src/test/resources/nested-array-struct.parquet rename to sql/core/src/test/resources/test-data/nested-array-struct.parquet diff --git a/sql/core/src/test/resources/numbers.csv b/sql/core/src/test/resources/test-data/numbers.csv similarity index 100% rename from sql/core/src/test/resources/numbers.csv rename to sql/core/src/test/resources/test-data/numbers.csv diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/test-data/old-repeated-int.parquet similarity index 100% rename from sql/core/src/test/resources/old-repeated-int.parquet rename to sql/core/src/test/resources/test-data/old-repeated-int.parquet diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/test-data/old-repeated-message.parquet similarity index 100% rename from sql/core/src/test/resources/old-repeated-message.parquet rename to sql/core/src/test/resources/test-data/old-repeated-message.parquet diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/test-data/parquet-thrift-compat.snappy.parquet similarity index 100% rename from sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet rename to sql/core/src/test/resources/test-data/parquet-thrift-compat.snappy.parquet diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/test-data/proto-repeated-string.parquet similarity index 100% rename from sql/core/src/test/resources/proto-repeated-string.parquet rename to sql/core/src/test/resources/test-data/proto-repeated-string.parquet diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/test-data/proto-repeated-struct.parquet similarity index 100% rename from sql/core/src/test/resources/proto-repeated-struct.parquet rename to sql/core/src/test/resources/test-data/proto-repeated-struct.parquet diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/test-data/proto-struct-with-array-many.parquet similarity index 100% rename from sql/core/src/test/resources/proto-struct-with-array-many.parquet rename to sql/core/src/test/resources/test-data/proto-struct-with-array-many.parquet diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/test-data/proto-struct-with-array.parquet similarity index 100% rename from sql/core/src/test/resources/proto-struct-with-array.parquet rename to sql/core/src/test/resources/test-data/proto-struct-with-array.parquet diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/test-data/simple_sparse.csv similarity index 100% rename from sql/core/src/test/resources/simple_sparse.csv rename to sql/core/src/test/resources/test-data/simple_sparse.csv diff --git a/sql/core/src/test/resources/text-partitioned/year=2014/data.txt b/sql/core/src/test/resources/test-data/text-partitioned/year=2014/data.txt similarity index 100% rename from sql/core/src/test/resources/text-partitioned/year=2014/data.txt rename to sql/core/src/test/resources/test-data/text-partitioned/year=2014/data.txt diff --git a/sql/core/src/test/resources/text-partitioned/year=2015/data.txt b/sql/core/src/test/resources/test-data/text-partitioned/year=2015/data.txt similarity index 100% rename from sql/core/src/test/resources/text-partitioned/year=2015/data.txt rename to sql/core/src/test/resources/test-data/text-partitioned/year=2015/data.txt diff --git a/sql/core/src/test/resources/text-suite.txt b/sql/core/src/test/resources/test-data/text-suite.txt similarity index 100% rename from sql/core/src/test/resources/text-suite.txt rename to sql/core/src/test/resources/test-data/text-suite.txt diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/test-data/text-suite2.txt similarity index 100% rename from sql/core/src/test/resources/text-suite2.txt rename to sql/core/src/test/resources/test-data/text-suite2.txt diff --git a/sql/core/src/test/resources/unescaped-quotes.csv b/sql/core/src/test/resources/test-data/unescaped-quotes.csv similarity index 100% rename from sql/core/src/test/resources/unescaped-quotes.csv rename to sql/core/src/test/resources/test-data/unescaped-quotes.csv diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 92aa7b95434d..3454caff6b82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -87,6 +87,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-17124 agg should be ordering preserving") { + val df = spark.range(2) + val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min") + assert(ret.schema.map(_.name) == Seq("id", "sum(id)", "count(id)", "min(id)")) + checkAnswer( + ret, + Row(0, 0, 1, 0) :: Row(1, 1, 1, 1) :: Nil + ) + } + test("rollup") { checkAnswer( courseSales.rollup("course", "year").sum("earnings"), @@ -467,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(error.message.contains("collect_set() cannot have map type data")) } + test("SPARK-17641: collect functions should not collect null values") { + val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b") + checkAnswer( + df.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq("1", "1"), Seq(2, 2, 4))) + ) + checkAnswer( + df.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq("1"), Seq(2, 4))) + ) + } + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), @@ -475,4 +497,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } + + test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") { + val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")) + .toDF("x", "y", "z") + checkAnswer( + df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), + Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 72f676e6225e..1230b921aa27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -58,4 +59,43 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val nullIntRow = df.selectExpr("i[1]").collect()(0) assert(nullIntRow == org.apache.spark.sql.Row(null)) } + + test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") { + val ds100_5 = Seq(S100_5()).toDS() + ds100_5.rdd.count + } } + +class S100( + val s1: String = "1", val s2: String = "2", val s3: String = "3", val s4: String = "4", + val s5: String = "5", val s6: String = "6", val s7: String = "7", val s8: String = "8", + val s9: String = "9", val s10: String = "10", val s11: String = "11", val s12: String = "12", + val s13: String = "13", val s14: String = "14", val s15: String = "15", val s16: String = "16", + val s17: String = "17", val s18: String = "18", val s19: String = "19", val s20: String = "20", + val s21: String = "21", val s22: String = "22", val s23: String = "23", val s24: String = "24", + val s25: String = "25", val s26: String = "26", val s27: String = "27", val s28: String = "28", + val s29: String = "29", val s30: String = "30", val s31: String = "31", val s32: String = "32", + val s33: String = "33", val s34: String = "34", val s35: String = "35", val s36: String = "36", + val s37: String = "37", val s38: String = "38", val s39: String = "39", val s40: String = "40", + val s41: String = "41", val s42: String = "42", val s43: String = "43", val s44: String = "44", + val s45: String = "45", val s46: String = "46", val s47: String = "47", val s48: String = "48", + val s49: String = "49", val s50: String = "50", val s51: String = "51", val s52: String = "52", + val s53: String = "53", val s54: String = "54", val s55: String = "55", val s56: String = "56", + val s57: String = "57", val s58: String = "58", val s59: String = "59", val s60: String = "60", + val s61: String = "61", val s62: String = "62", val s63: String = "63", val s64: String = "64", + val s65: String = "65", val s66: String = "66", val s67: String = "67", val s68: String = "68", + val s69: String = "69", val s70: String = "70", val s71: String = "71", val s72: String = "72", + val s73: String = "73", val s74: String = "74", val s75: String = "75", val s76: String = "76", + val s77: String = "77", val s78: String = "78", val s79: String = "79", val s80: String = "80", + val s81: String = "81", val s82: String = "82", val s83: String = "83", val s84: String = "84", + val s85: String = "85", val s86: String = "86", val s87: String = "87", val s88: String = "88", + val s89: String = "89", val s90: String = "90", val s91: String = "91", val s92: String = "92", + val s93: String = "93", val s94: String = "94", val s95: String = "95", val s96: String = "96", + val s97: String = "97", val s98: String = "98", val s99: String = "99", val s100: String = "100") +extends DefinedByConstructorParams + +case class S100_5( + s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(), + s4: S100 = new S100(), s5: S100 = new S100()) + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 4342c039aefc..4abf5e42b9c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -225,4 +225,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, null) :: Row(null, 2) :: Nil ) } + + test("SPARK-16991: Full outer join followed by inner join produces wrong results") { + val a = Seq((1, 2), (2, 3)).toDF("a", "b") + val b = Seq((2, 5), (3, 4)).toDF("a", "c") + val c = Seq((3, 1)).toDF("a", "d") + val ab = a.join(b, Seq("a"), "fullouter") + checkAnswer(ab.join(c, "a"), Row(3, null, 4, 1) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5151532ed2e1..da5c538eace2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1550,4 +1550,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(joined, Row("x", null, null)) checkAnswer(joined.filter($"new".isNull), Row("x", null, null)) } + + test("SPARK-16664: persist with more than 200 columns") { + val size = 201L + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(Seq.range(0, size)))) + val schemas = List.range(0, size).map(a => StructField("name" + a, LongType, true)) + val df = spark.createDataFrame(rdd, StructType(schemas), false) + assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) + } + + test("copy results for sampling with replacement") { + val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b") + val sampleDf = df.sample(true, 2.00) + val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect + assert(d.size == d.distinct.size) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 43cbc03b7aa0..f897cfb26d3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -422,6 +422,31 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3, 17, 27, 58, 62) } + test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage") { + val simpleUdf = udf((n: Int) => { + require(n != 1, "simpleUdf shouldn't see id=1!") + 1 + }) + + val df = Seq( + (0, "string0"), + (1, "string1"), + (2, "string2"), + (3, "string3"), + (4, "string4"), + (5, "string5"), + (6, "string6"), + (7, "string7"), + (8, "string8"), + (9, "string9") + ).toDF("id", "stringData") + val sampleDF = df.sample(false, 0.7, 50) + // After sampling, sampleDF doesn't contain id=1. + assert(!sampleDF.select("id").collect.contains(1)) + // simpleUdf should not encounter id=1. + checkAnswer(sampleDF.select(simpleUdf($"id")), List.fill(sampleDF.count.toInt)(Row(1))) + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -844,6 +869,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = spark.createDataset(data)(enc) checkDataset(ds, (("a", "b"), "c"), (null, "d")) } + + test("SPARK-16995: flat mapping on Dataset containing a column created with lit/expr") { + val df = Seq("1").toDF("a") + + import df.sparkSession.implicits._ + + checkDataset( + df.withColumn("b", lit(0)).as[ClassData] + .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() }) + checkDataset( + df.withColumn("b", expr("0")).as[ClassData] + .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() }) + } } case class Generic[T](id: T, value: Double) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala index eacf254cd183..98aa447fc056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.File import org.apache.spark.SparkException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext /** @@ -85,4 +86,28 @@ class MetadataCacheSuite extends QueryTest with SharedSQLContext { assert(newCount > 0 && newCount < 100) }} } + + test("case sensitivity support in temporary view refresh") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempView("view_refresh") { + withTempPath { (location: File) => + // Create a Parquet directory + spark.range(start = 0, end = 100, step = 1, numPartitions = 3) + .write.parquet(location.getAbsolutePath) + + // Read the directory in + spark.read.parquet(location.getAbsolutePath).createOrReplaceTempView("view_refresh") + + // Delete a file + deleteOneFileInDirectory(location) + intercept[SparkException](sql("select count(*) from view_refresh").first()) + + // Refresh and we should be able to read it again. + spark.catalog.refreshTable("vIeW_reFrEsH") + val newCount = sql("select count(*) from view_refresh").first().getLong(0) + assert(newCount > 0 && newCount < 100) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b15f38c2a71e..f96bd8cc8b5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.types.{Metadata, ObjectType} abstract class QueryTest extends PlanTest { @@ -242,9 +242,10 @@ abstract class QueryTest extends PlanTest { case p if p.getClass.getSimpleName == "MetastoreRelation" => return case _: MemoryPlan => return }.transformAllExpressions { - case a: ImperativeAggregate => return + case _: ImperativeAggregate => return case _: TypedAggregateExpression => return case Literal(_, _: ObjectType) => return + case _: UserDefinedGenerator => return } // bypass hive tests before we fix all corner cases in hive module. @@ -266,6 +267,14 @@ abstract class QueryTest extends PlanTest { val normalized1 = logicalPlan.transformAllExpressions { case udf: ScalaUDF => udf.copy(function = null) case gen: UserDefinedGenerator => gen.copy(function = null) + // After SPARK-17356: the JSON representation no longer has the Metadata. We need to remove + // the Metadata from the normalized plan so that we can compare this plan with the + // JSON-deserialzed plan. + case a @ Alias(child, name) if a.explicitMetadata.isDefined => + Alias(child, name)(a.exprId, a.qualifier, Some(Metadata.empty), a.isGenerated) + case a: AttributeReference if a.metadata != Metadata.empty => + AttributeReference(a.name, a.dataType, a.nullable, Metadata.empty)(a.exprId, a.qualifier, + a.isGenerated) } // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains @@ -394,6 +403,9 @@ object QueryTest { sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => s""" |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env("TZ")} + | |${df.queryExecution} |== Results == |$results diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d9659012fac5..cf250970c6c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql +import java.io.File import java.math.MathContext -import java.sql.Timestamp +import java.sql.{Date, Timestamp} -import org.apache.spark.AccumulatorSuite +import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.Aggregate @@ -38,26 +39,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { setupTestData() - test("having clause") { - withTempView("hav") { - Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v") - .createOrReplaceTempView("hav") - checkAnswer( - sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), - Row("one", 6) :: Row("three", 3) :: Nil) - } - } - - test("having condition contains grouping column") { - withTempView("hav") { - Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v") - .createOrReplaceTempView("hav") - checkAnswer( - sql("SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2"), - Row(1) :: Nil) - } - } - test("SPARK-8010: promote numeric to string") { val df = Seq((1, 1)).toDF("key", "value") df.createOrReplaceTempView("src") @@ -464,12 +445,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Nil) } - test("index into array") { - checkAnswer( - sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), - arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect()) - } - test("left semi greater than predicate") { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { checkAnswer( @@ -491,119 +466,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } - test("index into array of arrays") { - checkAnswer( - sql( - "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"), - arrayData.map(d => - Row(d.nestedData, - d.nestedData(0)(0), - d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq) - } - test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } - test("Group By Ordinal - basic") { - checkAnswer( - sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"), - sql("SELECT a, sum(b) FROM testData2 GROUP BY a")) - - // duplicate group-by columns - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - } - - test("Group By Ordinal - non aggregate expressions") { - checkAnswer( - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"), - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) - - checkAnswer( - sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"), - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) - } - - test("Group By Ordinal - non-foldable constant expression") { - checkAnswer( - sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"), - sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - } - - test("Group By Ordinal - alias") { - checkAnswer( - sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"), - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) - - checkAnswer( - sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) - } - - test("Group By Ordinal - constants") { - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) - } - - test("Group By Ordinal - negative cases") { - intercept[UnresolvedException[Aggregate]] { - sql("SELECT a, b FROM testData2 GROUP BY -1") - } - - intercept[UnresolvedException[Aggregate]] { - sql("SELECT a, b FROM testData2 GROUP BY 3") - } - - var e = intercept[UnresolvedException[Aggregate]]( - sql("SELECT SUM(a) FROM testData2 GROUP BY 1")) - assert(e.getMessage contains - "Invalid call to Group by position: the '1'th column in the select contains " + - "an aggregate function") - - e = intercept[UnresolvedException[Aggregate]]( - sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1")) - assert(e.getMessage contains - "Invalid call to Group by position: the '1'th column in the select contains " + - "an aggregate function") - - var ae = intercept[AnalysisException]( - sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2")) - assert(ae.getMessage contains - "nondeterministic expression rand(0) should not appear in grouping expression") - - ae = intercept[AnalysisException]( - sql("SELECT * FROM testData2 GROUP BY a, b, 1")) - assert(ae.getMessage contains - "Group by position: star is not allowed to use in the select list " + - "when using ordinals in group by") - } - - test("Group By Ordinal: spark.sql.groupByOrdinal=false") { - withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") { - // If spark.sql.groupByOrdinal=false, ignore the position number. - intercept[AnalysisException] { - sql("SELECT a, sum(b) FROM testData2 GROUP BY 1") - } - // '*' is not allowed to use in the select list when users specify ordinals in group by - checkAnswer( - sql("SELECT * FROM testData2 GROUP BY a, b, 1"), - sql("SELECT * FROM testData2 GROUP BY a, b")) - } - } - test("aggregates with nulls") { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + @@ -670,51 +538,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("limit") { - checkAnswer( - sql("SELECT * FROM testData LIMIT 9 + 1"), - testData.take(10).toSeq) - - checkAnswer( - sql("SELECT * FROM arrayData LIMIT CAST(1 AS Integer)"), - arrayData.collect().take(1).map(Row.fromTuple).toSeq) - - checkAnswer( - sql("SELECT * FROM mapData LIMIT 1"), - mapData.collect().take(1).map(Row.fromTuple).toSeq) - } - - test("non-foldable expressions in LIMIT") { - val e = intercept[AnalysisException] { - sql("SELECT * FROM testData LIMIT key > 3") - }.getMessage - assert(e.contains("The limit expression must evaluate to a constant value, " + - "but got (testdata.`key` > 3)")) - } - - test("Expressions in limit clause are not integer") { - var e = intercept[AnalysisException] { - sql("SELECT * FROM testData LIMIT true") - }.getMessage - assert(e.contains("The limit expression must be integer type, but got boolean")) - - e = intercept[AnalysisException] { - sql("SELECT * FROM testData LIMIT 'a'") - }.getMessage - assert(e.contains("The limit expression must be integer type, but got string")) - } - test("negative in LIMIT or TABLESAMPLE") { val expected = "The limit expression must be equal to or greater than 0, but got -1" var e = intercept[AnalysisException] { sql("SELECT * FROM testData TABLESAMPLE (-1 rows)") }.getMessage assert(e.contains(expected)) - - e = intercept[AnalysisException] { - sql("SELECT * FROM testData LIMIT -1") - }.getMessage - assert(e.contains(expected)) } test("CTE feature") { @@ -1339,134 +1168,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) } - test("Test to check we can use Long.MinValue") { - checkAnswer( - sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) - ) - - checkAnswer( - sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), - (1 to 100).map(Row(_)).toSeq - ) - } - - test("Floating point number format") { - checkAnswer( - sql("SELECT 0.3"), Row(BigDecimal(0.3)) - ) - - checkAnswer( - sql("SELECT -0.8"), Row(BigDecimal(-0.8)) - ) - - checkAnswer( - sql("SELECT .5"), Row(BigDecimal(0.5)) - ) - - checkAnswer( - sql("SELECT -.18"), Row(BigDecimal(-0.18)) - ) - } - - test("Auto cast integer type") { - checkAnswer( - sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) - ) - - checkAnswer( - sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) - ) - - checkAnswer( - sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) - ) - - checkAnswer( - sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) - ) - } - - test("Test to check we can apply sign to expression") { - - checkAnswer( - sql("SELECT -100"), Row(-100) - ) - - checkAnswer( - sql("SELECT +230"), Row(230) - ) - - checkAnswer( - sql("SELECT -5.2"), Row(BigDecimal(-5.2)) - ) - - checkAnswer( - sql("SELECT +6.8e0"), Row(6.8d) - ) - - checkAnswer( - sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) - ) - - checkAnswer( - sql("SELECT +key FROM testData WHERE key = 3"), Row(3) - ) - - checkAnswer( - sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) - ) - - checkAnswer( - sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) - ) - - checkAnswer( - sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) - ) - - checkAnswer( - sql("SELECT -MAX(key) FROM testData"), Row(-100) - ) - - checkAnswer( - sql("SELECT +MAX(key) FROM testData"), Row(100) - ) - - checkAnswer( - sql("SELECT - (-10)"), Row(10) - ) - - checkAnswer( - sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) - ) - - checkAnswer( - sql("SELECT - (+Max(key)) FROM testData"), Row(-100) - ) - - checkAnswer( - sql("SELECT - - 3"), Row(3) - ) - - checkAnswer( - sql("SELECT - + 20"), Row(-20) - ) - - checkAnswer( - sql("SELEcT - + 45"), Row(-45) - ) - - checkAnswer( - sql("SELECT + + 100"), Row(100) - ) - - checkAnswer( - sql("SELECT - - Max(key) FROM testData"), Row(100) - ) - - checkAnswer( - sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) - ) + testQuietly( + "SPARK-16748: SparkExceptions during planning should not wrapped in TreeNodeException") { + intercept[SparkException] { + val df = spark.range(0, 5).map(x => (1 / x).toString).toDF("a").orderBy("a") + df.queryExecution.toRdd // force physical planning, but not execution of the plan + } } test("Multiple join") { @@ -1928,21 +1635,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { e = intercept[AnalysisException] { sql(s"select id from `com.databricks.spark.avro`.`file_path`") } - assert(e.message.contains("Failed to find data source: com.databricks.spark.avro. " + - "Please use Spark package http://spark-packages.org/package/databricks/spark-avro")) + assert(e.message.contains("Failed to find data source: com.databricks.spark.avro.")) // data source type is case insensitive e = intercept[AnalysisException] { sql(s"select id from Avro.`file_path`") } - assert(e.message.contains("Failed to find data source: avro. Please use Spark package " + - "http://spark-packages.org/package/databricks/spark-avro")) + assert(e.message.contains("Failed to find data source: avro.")) e = intercept[AnalysisException] { sql(s"select id from avro.`file_path`") } - assert(e.message.contains("Failed to find data source: avro. Please use Spark package " + - "http://spark-packages.org/package/databricks/spark-avro")) + assert(e.message.contains("Failed to find data source: avro.")) e = intercept[AnalysisException] { sql(s"select id from `org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`") @@ -1987,15 +1691,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("SPARK-11032: resolve having correctly") { - withTempView("src") { - Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("src") - checkAnswer( - sql("SELECT MIN(t.i) FROM (SELECT * FROM src WHERE i > 0) t HAVING(COUNT(1) > 0)"), - Row(1)) - } - } - test("SPARK-11303: filter should not be pushed down into sample") { val df = spark.range(100) List(true, false).foreach { withReplacement => @@ -2495,70 +2190,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("order by ordinal number") { - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 DESC"), - sql("SELECT * FROM testData2 ORDER BY a DESC")) - // If the position is not an integer, ignore it. - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"), - sql("SELECT * FROM testData2 ORDER BY b ASC")) - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), - sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC")) - checkAnswer( - sql("SELECT * FROM testData2 SORT BY 1 DESC, 2"), - sql("SELECT * FROM testData2 SORT BY a DESC, b ASC")) - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"), - Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) - } - - test("order by ordinal number - negative cases") { - intercept[UnresolvedException[SortOrder]] { - sql("SELECT * FROM testData2 ORDER BY 0") - } - intercept[UnresolvedException[SortOrder]] { - sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC") - } - intercept[UnresolvedException[SortOrder]] { - sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC") - } - } - - test("order by ordinal number with conf spark.sql.orderByOrdinal=false") { - withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") { - // If spark.sql.orderByOrdinal=false, ignore the position number. - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), - sql("SELECT * FROM testData2 ORDER BY b ASC")) - } - } - - test("natural join") { - val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") - val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") - withTempView("nt1", "nt2") { - df1.createOrReplaceTempView("nt1") - df2.createOrReplaceTempView("nt2") - checkAnswer( - sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), - Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) - - checkAnswer( - sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) - - checkAnswer( - sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) - - checkAnswer( - sql("SELECT count(*) FROM nt1 natural full outer join nt2"), - Row(4) :: Nil) - } - } - test("join with using clause") { val df1 = Seq(("r1c1", "r1c2", "t1r1c3"), ("r2c1", "r2c2", "t1r2c3"), ("r3c1x", "r3c2", "t1r3c3")).toDF("c1", "c2", "c3") @@ -2942,6 +2573,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(s"$expected") :: Nil) } + test("SPARK-16975: Column-partition path starting '_' should be handled correctly") { + withTempDir { dir => + val parquetDir = new File(dir, "parquet").getCanonicalPath + spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir) + spark.read.parquet(parquetDir) + } + } + test("SPARK-16644: Aggregate should not put aggregate expressions to constraints") { withTable("tbl") { sql("CREATE TABLE tbl(a INT, b INT) USING parquet") @@ -2973,4 +2612,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { data.selectExpr("`part.col1`", "`col.1`")) } } + + test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") { + val numRecordsRead = spark.sparkContext.longAccumulator + spark.range(1, 100, 1, numPartitions = 10).map { x => + numRecordsRead.add(1) + x + }.limit(1).queryExecution.toRdd.count() + assert(numRecordsRead.value === 10) + } + + test("CREATE TABLE USING should not fail if a same-name temp view exists") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + sql("CREATE TABLE same_name(i int) USING json") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + assert(spark.table("default.same_name").collect().isEmpty) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala new file mode 100644 index 000000000000..55d5a56f1040 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File +import java.util.{Locale, TimeZone} + +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +/** + * End-to-end test cases for SQL queries. + * + * Each case is loaded from a file in "spark/sql/core/src/test/resources/sql-tests/inputs". + * Each case has a golden result file in "spark/sql/core/src/test/resources/sql-tests/results". + * + * To run the entire test suite: + * {{{ + * build/sbt "sql/test-only *SQLQueryTestSuite" + * }}} + * + * To run a single test file upon change: + * {{{ + * build/sbt "~sql/test-only *SQLQueryTestSuite -- -z inline-table.sql" + * }}} + * + * To re-generate golden files, run: + * {{{ + * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/test-only *SQLQueryTestSuite" + * }}} + * + * The format for input files is simple: + * 1. A list of SQL queries separated by semicolon. + * 2. Lines starting with -- are treated as comments and ignored. + * + * For example: + * {{{ + * -- this is a comment + * select 1, -1; + * select current_date; + * }}} + * + * The format for golden result files look roughly like: + * {{{ + * -- some header information + * + * -- !query 0 + * select 1, -1 + * -- !query 0 schema + * struct<...schema...> + * -- !query 0 output + * ... data row 1 ... + * ... data row 2 ... + * ... + * + * -- !query 1 + * ... + * }}} + */ +class SQLQueryTestSuite extends QueryTest with SharedSQLContext { + + private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" + + private val baseResourcePath = { + // If regenerateGoldenFiles is true, we must be running this in SBT and we use hard-coded + // relative path. Otherwise, we use classloader's getResource to find the location. + if (regenerateGoldenFiles) { + java.nio.file.Paths.get("src", "test", "resources", "sql-tests").toFile + } else { + val res = getClass.getClassLoader.getResource("sql-tests") + new File(res.getFile) + } + } + + private val inputFilePath = new File(baseResourcePath, "inputs").getAbsolutePath + private val goldenFilePath = new File(baseResourcePath, "results").getAbsolutePath + + /** List of test cases to ignore, in lower cases. */ + private val blackList = Set( + "blacklist.sql" // Do NOT remove this one. It is here to test the blacklist functionality. + ) + + // Create all the test cases. + listTestCases().foreach(createScalaTestCase) + + /** A test case. */ + private case class TestCase(name: String, inputFile: String, resultFile: String) + + /** A single SQL query's output. */ + private case class QueryOutput(sql: String, schema: String, output: String) { + def toString(queryIndex: Int): String = { + // We are explicitly not using multi-line string due to stripMargin removing "|" in output. + s"-- !query $queryIndex\n" + + sql + "\n" + + s"-- !query $queryIndex schema\n" + + schema + "\n" + + s"-- !query $queryIndex output\n" + + output + } + } + + private def createScalaTestCase(testCase: TestCase): Unit = { + if (blackList.contains(testCase.name.toLowerCase)) { + // Create a test case to ignore this case. + ignore(testCase.name) { /* Do nothing */ } + } else { + // Create a test case to run this case. + test(testCase.name) { runTest(testCase) } + } + } + + /** Run a test case. */ + private def runTest(testCase: TestCase): Unit = { + val input = fileToString(new File(testCase.inputFile)) + + // List of SQL queries to run + val queries: Seq[String] = { + val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") + // note: this is not a robust way to split queries using semicolon, but works for now. + cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + } + + // Create a local SparkSession to have stronger isolation between different test cases. + // This does not isolate catalog changes. + val localSparkSession = spark.newSession() + loadTestData(localSparkSession) + + // Run the SQL queries preparing them for comparison. + val outputs: Seq[QueryOutput] = queries.map { sql => + val (schema, output) = getNormalizedResult(localSparkSession, sql) + // We might need to do some query canonicalization in the future. + QueryOutput( + sql = sql, + schema = schema.catalogString, + output = output.mkString("\n").trim) + } + + if (regenerateGoldenFiles) { + // Again, we are explicitly not using multi-line string due to stripMargin removing "|". + val goldenOutput = { + s"-- Automatically generated by ${getClass.getSimpleName}\n" + + s"-- Number of queries: ${outputs.size}\n\n\n" + + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" + } + stringToFile(new File(testCase.resultFile), goldenOutput) + } + + // Read back the golden file. + val expectedOutputs: Seq[QueryOutput] = { + val goldenOutput = fileToString(new File(testCase.resultFile)) + val segments = goldenOutput.split("-- !query.+\n") + + // each query has 3 segments, plus the header + assert(segments.size == outputs.size * 3 + 1, + s"Expected ${outputs.size * 3 + 1} blocks in result file but got ${segments.size}. " + + s"Try regenerate the result files.") + Seq.tabulate(outputs.size) { i => + QueryOutput( + sql = segments(i * 3 + 1).trim, + schema = segments(i * 3 + 2).trim, + output = segments(i * 3 + 3).trim + ) + } + } + + // Compare results. + assertResult(expectedOutputs.size, s"Number of queries should be ${expectedOutputs.size}") { + outputs.size + } + + outputs.zip(expectedOutputs).zipWithIndex.foreach { case ((output, expected), i) => + assertResult(expected.sql, s"SQL query did not match for query #$i\n${expected.sql}") { + output.sql + } + assertResult(expected.schema, s"Schema did not match for query #$i\n${expected.sql}") { + output.schema + } + assertResult(expected.output, s"Result dit not match for query #$i\n${expected.sql}") { + output.output + } + } + } + + /** Executes a query and returns the result as (schema of the output, normalized output). */ + private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = { + // Returns true if the plan is supposed to be sorted. + def isSorted(plan: LogicalPlan): Boolean = plan match { + case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false + case PhysicalOperation(_, _, Sort(_, true, _)) => true + case _ => plan.children.iterator.exists(isSorted) + } + + try { + val df = session.sql(sql) + val schema = df.schema + val answer = df.queryExecution.hiveResultString() + + // If the output is not pre-sorted, sort it. + if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) + + } catch { + case NonFatal(e) => + // If there is an exception, put the exception class followed by the message. + (StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage)) + } + } + + private def listTestCases(): Seq[TestCase] = { + listFilesRecursively(new File(inputFilePath)).map { file => + val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out" + TestCase(file.getName, file.getAbsolutePath, resultFile) + } + } + + /** Returns all the files (not directories) in a directory, recursively. */ + private def listFilesRecursively(path: File): Seq[File] = { + val (dirs, files) = path.listFiles().partition(_.isDirectory) + files ++ dirs.flatMap(listFilesRecursively) + } + + /** Load built-in test tables into the SparkSession. */ + private def loadTestData(session: SparkSession): Unit = { + import session.implicits._ + + (1 to 100).map(i => (i, i.toString)).toDF("key", "value").createOrReplaceTempView("testdata") + + ((Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: (Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + .toDF("arraycol", "nestedarraycol") + .createOrReplaceTempView("arraydata") + + (Tuple1(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + Tuple1(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + Tuple1(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + Tuple1(Map(1 -> "a4", 2 -> "b4")) :: + Tuple1(Map(1 -> "a5")) :: Nil) + .toDF("mapcol") + .createOrReplaceTempView("mapdata") + } + + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + override def beforeAll(): Unit = { + super.beforeAll() + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + RuleExecutor.resetTime() + } + + override def afterAll(): Unit = { + try { + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + + // For debugging dump some statistics about how much time was spent in various optimizer rules + logWarning(RuleExecutor.dumpTimeSpent()) + } finally { + super.afterAll() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 418345b9ee8f..386d13d07a95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -100,6 +100,7 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.conf.get("key2") == "value2") assert(session.sparkContext.conf.get("key1") == "value1") assert(session.sparkContext.conf.get("key2") == "value2") + assert(session.sparkContext.conf.get("spark.app.name") == "test") session.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 7bea2f6ad0db..9be2de9c7d71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -92,6 +92,18 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } + test("non-matching optional group") { + val df = Seq(Tuple1("aaaac")).toDF("s") + checkAnswer( + df.select(regexp_extract($"s", "(foo)", 1)), + Row("") + ) + checkAnswer( + df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)), + Row("") + ) + } + test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 13490c35679a..375da224aaa7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ @@ -415,6 +415,44 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements skips sort when required ordering is semantically equal to " + + "existing ordering") { + val exprId: ExprId = NamedExpression.newExprId + val attribute1 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId, + qualifier = Some("col1_qualifier") + ) + + val attribute2 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId) + + val orderingA1 = SortOrder(attribute1, Ascending) + val orderingA2 = SortOrder(attribute2, Ascending) + + assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") + assert(orderingA1.semanticEquals(orderingA2), + s"$orderingA1 should be semantically equal to $orderingA2") + + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA2)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + // This is a regression test for SPARK-11135 test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { val orderingA = SortOrder(Literal(1), Ascending) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 3217e34bd8ad..7e317a4d8026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { checkThatPlansAgree( generateRandomInputData(), input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, None, input)), + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, @@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { generateRandomInputData(), input => noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Some(Seq(input.output.last)), input)), + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala new file mode 100644 index 000000000000..9dcaca0ca93e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.util.Benchmark + + +/** + * Benchmark to measure performance for wide table. + * To run this: + * build/sbt "sql/test-only *benchmark.BenchmarkWideTable" + * + * Benchmarks in this file are skipped in normal builds. + */ +class BenchmarkWideTable extends BenchmarkBase { + + ignore("project on wide table") { + val N = 1 << 20 + val df = sparkSession.range(N) + val columns = (0 until 400).map{ i => s"id as id$i"} + val benchmark = new Benchmark("projection on wide table", N) + benchmark.addCase("wide table", numIters = 5) { iter => + df.selectExpr(columns : _*).queryExecution.toRdd.count() + } + benchmark.run() + + /** + * Here are some numbers with different split threshold: + * + * Split threshold methods Rate(M/s) Per Row(ns) + * 10 400 0.4 2279 + * 100 200 0.6 1554 + * 1k 37 0.9 1116 + * 8k 5 0.5 2025 + * 64k 1 0.0 21649 + */ + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index af3ed14c122d..937839644ad5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -227,7 +227,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val columnTypes1 = List.fill(length1)(IntegerType) val columnarIterator1 = GenerateColumnAccessor.generate(columnTypes1) - val length2 = 10000 + // SPARK-16664: the limit of janino is 8117 + val length2 = 8117 val columnTypes2 = List.fill(length2)(IntegerType) val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 7b96f4c99ab5..8d74884df927 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -564,6 +564,14 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("alter table: recover partitions") { + val sql = "ALTER TABLE table_name RECOVER PARTITIONS" + val parsed = parser.parsePlan(sql) + val expected = AlterTableRecoverPartitionsCommand( + TableIdentifier("table_name", None)) + comparePlans(parsed, expected) + } + test("alter view: add partition (not supported)") { assertUnsupported( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f2ec393c30ec..1f5492e8a0b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -111,10 +111,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) } - private def appendTrailingSlash(path: String): String = { - if (!path.endsWith(File.separator)) path + File.separator else path - } - test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -122,18 +118,19 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val path = tmpDir.toString // The generated temp path is not qualified. assert(!path.startsWith("file:/")) - sql(s"CREATE DATABASE db1 LOCATION '$path'") + val uri = tmpDir.toURI + sql(s"CREATE DATABASE db1 LOCATION '$uri'") val pathInCatalog = new Path(catalog.getDatabaseMetadata("db1").locationUri).toUri assert("file" === pathInCatalog.getScheme) - val expectedPath = if (path.endsWith(File.separator)) path.dropRight(1) else path - assert(expectedPath === pathInCatalog.getPath) + val expectedPath = new Path(path).toUri + assert(expectedPath.getPath === pathInCatalog.getPath) withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { sql(s"CREATE DATABASE db2") - val pathInCatalog = new Path(catalog.getDatabaseMetadata("db2").locationUri).toUri - assert("file" === pathInCatalog.getScheme) - val expectedPath = appendTrailingSlash(spark.sessionState.conf.warehousePath) + "db2.db" - assert(expectedPath === pathInCatalog.getPath) + val pathInCatalog2 = new Path(catalog.getDatabaseMetadata("db2").locationUri).toUri + assert("file" === pathInCatalog2.getScheme) + val expectedPath2 = new Path(spark.sessionState.conf.warehousePath + "/" + "db2.db").toUri + assert(expectedPath2.getPath === pathInCatalog2.getPath) } sql("DROP DATABASE db1") @@ -141,6 +138,13 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + private def makeQualifiedPath(path: String): String = { + // copy-paste from SessionCatalog + val hadoopPath = new Path(path) + val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration) + fs.makeQualified(hadoopPath).toString + } + test("Create/Drop Database") { withTempDir { tmpDir => val path = tmpDir.toString @@ -154,8 +158,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = - "file:" + appendTrailingSlash(path) + s"$dbNameWithoutBackTicks.db" + val expectedLocation = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db") assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -181,8 +184,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbName) val expectedLocation = - "file:" + appendTrailingSlash(System.getProperty("user.dir")) + - s"spark-warehouse/$dbName.db" + makeQualifiedPath(s"${System.getProperty("user.dir")}/spark-warehouse" + + "/" + s"$dbName.db") assert(db1 == CatalogDatabase( dbName, "", @@ -200,17 +203,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") withTempDir { tmpDir => - val path = tmpDir.toString - val dbPath = "file:" + path + val path = new Path(tmpDir.toString).toUri.toString databaseNames.foreach { dbName => try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName Location '$path'") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) + val expPath = makeQualifiedPath(tmpDir.toString) assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", - if (dbPath.endsWith(File.separator)) dbPath.dropRight(1) else dbPath, + expPath, Map.empty)) sql(s"DROP DATABASE $dbName CASCADE") assert(!catalog.databaseExists(dbNameWithoutBackTicks)) @@ -233,8 +236,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = - "file:" + appendTrailingSlash(path) + s"$dbNameWithoutBackTicks.db" + val expectedLocation = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db") assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -275,7 +277,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { databaseNames.foreach { dbName => try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) - val location = "file:" + appendTrailingSlash(path) + s"$dbNameWithoutBackTicks.db" + val location = makeQualifiedPath(path + "/" + s"$dbNameWithoutBackTicks.db") sql(s"CREATE DATABASE $dbName") @@ -393,6 +395,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.getTableMetadata(tableIdent1) === expectedTable) } + test("Analyze in-memory cataloged tables(SimpleCatalogRelation)") { + withTable("tbl") { + sql("CREATE TABLE tbl(a INT, b INT) USING parquet") + val e = intercept[AnalysisException] { + sql("ANALYZE TABLE tbl COMPUTE STATISTICS") + }.getMessage + assert(e.contains("ANALYZE TABLE is only supported for Hive tables, " + + "but 'tbl' is a SimpleCatalogRelation")) + } + } + test("create table using") { val catalog = spark.sessionState.catalog withTable("tbl") { @@ -436,7 +449,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("create temporary view using") { - val csvFile = Thread.currentThread().getContextClassLoader.getResource("cars.csv").toString() + val csvFile = + Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString withView("testview") { sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1: String, c2: String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + @@ -628,6 +642,64 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testAddPartitions(isDatasourceTable = true) } + test("alter table: recover partitions (sequential)") { + withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { + testRecoverPartitions() + } + } + + test("alter table: recover partition (parallel)") { + withSQLConf("spark.rdd.parallelListingThreshold" -> "1") { + testRecoverPartitions() + } + } + + private def testRecoverPartitions() { + val catalog = spark.sessionState.catalog + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist RECOVER PARTITIONS") + } + + val tableIdent = TableIdentifier("tab1") + createTable(catalog, tableIdent) + val part1 = Map("a" -> "1", "b" -> "5") + createTablePartition(catalog, part1, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + + val part2 = Map("a" -> "2", "b" -> "6") + val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + // valid + fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) + fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file + fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "_SUCCESS")) // file + fs.mkdirs(new Path(new Path(root, "A=2"), "B=6")) + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "b.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "c.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), ".hiddenFile")) // file + fs.mkdirs(new Path(new Path(root, "A=2/B=6"), "_temporary")) + + // invalid + fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name + fs.mkdirs(new Path(new Path(root, "b=1"), "a=1")) // wrong order + fs.mkdirs(new Path(root, "a=4")) // not enough columns + fs.createNewFile(new Path(new Path(root, "a=1"), "b=4")) // file + fs.createNewFile(new Path(new Path(root, "a=1"), "_SUCCESS")) // _SUCCESS + fs.mkdirs(new Path(new Path(root, "a=1"), "_temporary")) // _temporary + fs.mkdirs(new Path(new Path(root, "a=1"), ".b=4")) // start with . + + try { + sql("ALTER TABLE tab1 RECOVER PARTITIONS") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2)) + assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } finally { + fs.delete(root, true) + } + } + test("alter table: add partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } @@ -1352,7 +1424,9 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { (1 to 10).map { i => (i, i) }.toDF("a", "b").createTempView("my_temp_tab") sql(s"CREATE EXTERNAL TABLE my_ext_tab LOCATION '$path'") sql(s"CREATE VIEW my_view AS SELECT 1") - assertUnsupported("TRUNCATE TABLE my_temp_tab") + intercept[NoSuchTableException] { + sql("TRUNCATE TABLE my_temp_tab") + } assertUnsupported("TRUNCATE TABLE my_ext_tab") assertUnsupported("TRUNCATE TABLE my_view") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala index 0d9ea512729b..4f12df9c4985 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.datasources import java.io.File +import java.net.URI +import scala.collection.mutable import scala.language.reflectiveCalls -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.SharedSQLContext @@ -67,4 +69,57 @@ class FileCatalogSuite extends SharedSQLContext { } } + + test("ListingFileCatalog: folders that don't exist don't throw exceptions") { + withTempDir { dir => + val deletedFolder = new File(dir, "deleted") + assert(!deletedFolder.exists()) + val catalog1 = new ListingFileCatalog( + spark, Seq(new Path(deletedFolder.getCanonicalPath)), Map.empty, None, + ignoreFileNotFound = true) + // doesn't throw an exception + assert(catalog1.listLeafFiles(catalog1.paths).isEmpty) + } + } + + test("SPARK-17613 - PartitioningAwareFileCatalog: base path w/o '/' at end") { + class MockCatalog( + override val paths: Seq[Path]) extends PartitioningAwareFileCatalog(spark, Map.empty, None) { + + override def refresh(): Unit = {} + + override def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = mutable.LinkedHashMap( + new Path("mockFs://some-bucket/file1.json") -> new FileStatus() + ) + + override def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = Map( + new Path("mockFs://some-bucket/") -> Array(new FileStatus()) + ) + + override def partitionSpec(): PartitionSpec = { + PartitionSpec.emptySpec + } + } + + withSQLConf( + "fs.mockFs.impl" -> classOf[FakeParentPathFileSystem].getName, + "fs.mockFs.impl.disable.cache" -> "true") { + val pathWithSlash = new Path("mockFs://some-bucket/") + assert(pathWithSlash.getParent === null) + val pathWithoutSlash = new Path("mockFs://some-bucket") + assert(pathWithoutSlash.getParent === null) + val catalog1 = new MockCatalog(Seq(pathWithSlash)) + val catalog2 = new MockCatalog(Seq(pathWithoutSlash)) + assert(catalog1.allFiles().nonEmpty) + assert(catalog2.allFiles().nonEmpty) + } + } +} + +class FakeParentPathFileSystem extends RawLocalFileSystem { + override def getScheme: String = "mockFs" + + override def getUri: URI = { + URI.create("mockFs://some-bucket") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8d8a18fa9332..2f1edb097492 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.execution.DataSourceScanExec +import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ @@ -407,6 +407,39 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } + test("[SPARK-16818] partition pruned file scans implement sameResult correctly") { + withTempPath { path => + val tempDir = path.getCanonicalPath + spark.range(100) + .selectExpr("id", "id as b") + .write + .partitionBy("id") + .parquet(tempDir) + val df = spark.read.parquet(tempDir) + def getPlan(df: DataFrame): SparkPlan = { + df.queryExecution.executedPlan + } + assert(getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 2")))) + assert(!getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 3")))) + } + } + + test("[SPARK-16818] exchange reuse respects differences in partition pruning") { + spark.conf.set("spark.sql.exchange.reuse", true) + withTempPath { path => + val tempDir = path.getCanonicalPath + spark.range(10) + .selectExpr("id % 2 as a", "id % 3 as b", "id as c") + .write + .partitionBy("a") + .parquet(tempDir) + val df = spark.read.parquet(tempDir) + val df1 = df.where("a = 0").groupBy("b").agg("c" -> "sum") + val df2 = df.where("a = 1").groupBy("b").agg("c" -> "sum") + checkAnswer(df1.join(df2, "b"), Row(0, 6, 12) :: Row(1, 4, 8) :: Row(2, 10, 5) :: Nil) + } + } + // Helpers for checking the arguments passed to the FileFormat. protected val checkPartitionSchema = @@ -474,7 +507,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val bucketed = df.queryExecution.analyzed transform { case l @ LogicalRelation(r: HadoopFsRelation, _, _) => l.copy(relation = - r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) + r.copy(bucketSpec = + Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))(r.sparkSession)) } Dataset.ofRows(spark, bucketed) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index dbe3af49c90c..5e00f669b859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -60,9 +60,9 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("dateFormat" -> "yyyy-mm")) + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm")) assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("dateFormat" -> "yyyy")) + options = new CSVOptions(Map("timestampFormat" -> "yyyy")) assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 311f1fa8d2af..8b6eb10d964d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -22,34 +22,35 @@ import java.nio.charset.UnsupportedCharsetException import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { import testImplicits._ - private val carsFile = "cars.csv" - private val carsMalformedFile = "cars-malformed.csv" - private val carsFile8859 = "cars_iso-8859-1.csv" - private val carsTsvFile = "cars.tsv" - private val carsAltFile = "cars-alternative.csv" - private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv" - private val carsNullFile = "cars-null.csv" - private val carsBlankColName = "cars-blank-column-name.csv" - private val emptyFile = "empty.csv" - private val commentsFile = "comments.csv" - private val disableCommentsFile = "disable_comments.csv" - private val boolFile = "bool.csv" - private val decimalFile = "decimal.csv" - private val simpleSparseFile = "simple_sparse.csv" - private val numbersFile = "numbers.csv" - private val datesFile = "dates.csv" - private val unescapedQuotesFile = "unescaped-quotes.csv" + private val carsFile = "test-data/cars.csv" + private val carsMalformedFile = "test-data/cars-malformed.csv" + private val carsFile8859 = "test-data/cars_iso-8859-1.csv" + private val carsTsvFile = "test-data/cars.tsv" + private val carsAltFile = "test-data/cars-alternative.csv" + private val carsUnbalancedQuotesFile = "test-data/cars-unbalanced-quotes.csv" + private val carsNullFile = "test-data/cars-null.csv" + private val carsBlankColName = "test-data/cars-blank-column-name.csv" + private val emptyFile = "test-data/empty.csv" + private val commentsFile = "test-data/comments.csv" + private val disableCommentsFile = "test-data/disable_comments.csv" + private val boolFile = "test-data/bool.csv" + private val decimalFile = "test-data/decimal.csv" + private val simpleSparseFile = "test-data/simple_sparse.csv" + private val numbersFile = "test-data/numbers.csv" + private val datesFile = "test-data/dates.csv" + private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -477,7 +478,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val options = Map( "header" -> "true", "inferSchema" -> "true", - "dateFormat" -> "dd/MM/yyyy hh:mm") + "timestampFormat" -> "dd/MM/yyyy HH:mm") val results = spark.read .format("csv") .options(options) @@ -485,7 +486,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm") val expected = Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)), Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)), @@ -517,6 +518,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case (expectedDate, date) => // As it truncates the hours, minutes and etc., we only check // if the dates (days, months and years) are the same via `toString()`. + println("Expected: "+expectedDate) assert(expectedDate.toString === date.toString) } } @@ -553,7 +555,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkValues = false) val results = cars.collect() - assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) + assert(results(0).toSeq === Array(2012, "Tesla", "S", null, null)) assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } @@ -679,6 +681,19 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands").write.csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + .write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[SparkException] { + val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + spark.range(1).write.csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getCause.getMessage + assert(msg.contains("Unsupported type: array")) } } @@ -691,4 +706,155 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkValues = false) } + + test("Write timestamps correctly in ISO8601 format by default") { + withTempDir { dir => + val iso8601timestampsPath = s"${dir.getCanonicalPath}/iso8601timestamps.csv" + val timestamps = spark.read + .format("csv") + .option("inferSchema", "true") + .option("header", "true") + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + timestamps.write + .format("csv") + .option("header", "true") + .save(iso8601timestampsPath) + + // This will load back the timestamps as string. + val iso8601Timestamps = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(iso8601timestampsPath) + + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ") + val expectedTimestamps = timestamps.collect().map { r => + // This should be ISO8601 formatted string. + Row(iso8501.format(r.toSeq.head)) + } + + checkAnswer(iso8601Timestamps, expectedTimestamps) + } + } + + test("Write dates correctly in ISO8601 format by default") { + withTempDir { dir => + val customSchema = new StructType(Array(StructField("date", DateType, true))) + val iso8601datesPath = s"${dir.getCanonicalPath}/iso8601dates.csv" + val dates = spark.read + .format("csv") + .schema(customSchema) + .option("header", "true") + .option("inferSchema", "false") + .option("dateFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + dates.write + .format("csv") + .option("header", "true") + .save(iso8601datesPath) + + // This will load back the dates as string. + val iso8601dates = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(iso8601datesPath) + + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd") + val expectedDates = dates.collect().map { r => + // This should be ISO8601 formatted string. + Row(iso8501.format(r.toSeq.head)) + } + + checkAnswer(iso8601dates, expectedDates) + } + } + + test("Roundtrip in reading and writing timestamps") { + withTempDir { dir => + val iso8601timestampsPath = s"${dir.getCanonicalPath}/iso8601timestamps.csv" + val timestamps = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(datesFile)) + + timestamps.write + .format("csv") + .option("header", "true") + .save(iso8601timestampsPath) + + val iso8601timestamps = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(iso8601timestampsPath) + + checkAnswer(iso8601timestamps, timestamps) + } + } + + test("Write dates correctly with dateFormat option") { + val customSchema = new StructType(Array(StructField("date", DateType, true))) + withTempDir { dir => + // With dateFormat option. + val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.csv" + val datesWithFormat = spark.read + .format("csv") + .schema(customSchema) + .option("header", "true") + .option("dateFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + datesWithFormat.write + .format("csv") + .option("header", "true") + .option("dateFormat", "yyyy/MM/dd") + .save(datesWithFormatPath) + + // This will load back the dates as string. + val stringDatesWithFormat = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(datesWithFormatPath) + val expectedStringDatesWithFormat = Seq( + Row("2015/08/26"), + Row("2014/10/27"), + Row("2016/01/28")) + + checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat) + } + } + + test("Write timestamps correctly with dateFormat option") { + withTempDir { dir => + // With dateFormat option. + val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.csv" + val timestampsWithFormat = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + timestampsWithFormat.write + .format("csv") + .option("header", "true") + .option("timestampFormat", "yyyy/MM/dd HH:mm") + .save(timestampsWithFormatPath) + + // This will load back the timestamps as string. + val stringTimestampsWithFormat = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(timestampsWithFormatPath) + val expectedStringTimestampsWithFormat = Seq( + Row("2015/08/26 18:00"), + Row("2014/10/27 18:30"), + Row("2016/01/28 20:00")) + + checkAnswer(stringTimestampsWithFormat, expectedStringTimestampsWithFormat) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 26b33b24efc3..dae92f626c22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -68,16 +68,46 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Nullable types are handled") { - assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) == null) + assertNull( + CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", BooleanType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", TimestampType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DateType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", StringType, nullable = true, CSVOptions("nullValue", "-"))) } - test("String type should always return the same as the input") { + test("String type should also respect `nullValue`") { + assertNull( + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions())) + assert( + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == + UTF8String.fromString("")) + assert( - CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) == + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions("nullValue", "null")) == UTF8String.fromString("")) assert( - CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions("nullValue", "null")) == UTF8String.fromString("")) + + assertNull( + CSVTypeCast.castTo(null, StringType, nullable = true, CSVOptions("nullValue", "null"))) } test("Throws exception for empty string with non null type") { @@ -96,13 +126,18 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) assert(CSVTypeCast.castTo("true", BooleanType) == true) - val options = CSVOptions("dateFormat", "dd/MM/yyyy hh:mm") + val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm") val customTimestamp = "31/01/2015 00:00" - val expectedTime = options.dateFormat.parse("31/01/2015 00:00").getTime - assert(CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, options) == - expectedTime * 1000L) - assert(CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, options) == - DateTimeUtils.millisToDays(expectedTime)) + val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime + val castedTimestamp = + CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, timestampsOptions) + assert(castedTimestamp == expectedTime * 1000L) + + val customDate = "31/01/2015" + val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy") + val expectedDate = dateOptions.dateFormat.parse(customDate).getTime + val castedDate = CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, dateOptions) + assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) val timestamp = "2015-01-01 00:00:00" assert(CSVTypeCast.castTo(timestamp, TimestampType) == @@ -165,20 +200,4 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(doubleVal2 == Double.PositiveInfinity) } - test("Type-specific null values are used for casting") { - assertNull( - CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0b0e64ac7273..1ba5b8123117 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -64,9 +64,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { generator.flush() } + val dummyOption = new JSONOptions(Map.empty[String, String]) Utils.tryWithResource(factory.createParser(writer.toString)) { parser => parser.nextToken() - JacksonParser.convertRootField(factory, parser, dataType) + JacksonParser.convertRootField(factory, parser, dataType, dummyOption) } } @@ -99,15 +100,15 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), - enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(DateTimeUtils.millisToDays(3601000), - enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), + enforceCorrectType(ISO8601Time1, TimestampType)) checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), - enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(DateTimeUtils.millisToDays(10801000), - enforceCorrectType(ISO8601Time2, DateType)) + enforceCorrectType(ISO8601Time2, TimestampType)) + + val ISO8601Date = "1970-01-01" + checkTypePromotion(DateTimeUtils.millisToDays(32400000), + enforceCorrectType(ISO8601Date, DateType)) } test("Get compatible type") { @@ -1662,4 +1663,61 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(df.schema.size === 2) df.collect() } + + test("Write dates correctly with dateFormat option") { + val customSchema = new StructType(Array(StructField("date", DateType, true))) + withTempDir { dir => + // With dateFormat option. + val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.json" + val datesWithFormat = spark.read + .schema(customSchema) + .option("dateFormat", "dd/MM/yyyy HH:mm") + .json(datesRecords) + + datesWithFormat.write + .format("json") + .option("dateFormat", "yyyy/MM/dd") + .save(datesWithFormatPath) + + // This will load back the dates as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) + val stringDatesWithFormat = spark.read + .schema(stringSchema) + .json(datesWithFormatPath) + val expectedStringDatesWithFormat = Seq( + Row("2015/08/26"), + Row("2014/10/27"), + Row("2016/01/28")) + + checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat) + } + } + + test("Write timestamps correctly with dateFormat option") { + val customSchema = new StructType(Array(StructField("date", TimestampType, true))) + withTempDir { dir => + // With dateFormat option. + val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.json" + val timestampsWithFormat = spark.read + .schema(customSchema) + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .json(datesRecords) + timestampsWithFormat.write + .format("json") + .option("timestampFormat", "yyyy/MM/dd HH:mm") + .save(timestampsWithFormatPath) + + // This will load back the timestamps as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) + val stringTimestampsWithFormat = spark.read + .schema(stringSchema) + .json(timestampsWithFormatPath) + val expectedStringDatesWithFormat = Seq( + Row("2015/08/26 18:00"), + Row("2014/10/27 18:30"), + Row("2016/01/28 20:00")) + + checkAnswer(stringTimestampsWithFormat, expectedStringDatesWithFormat) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index f4a333664386..d1d82fd5658b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -222,6 +222,12 @@ private[json] trait TestJsonData { spark.sparkContext.parallelize( s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) + def datesRecords: RDD[String] = + spark.sparkContext.parallelize( + """{"date": "26/08/2015 18:00"}""" :: + """{"date": "27/10/2014 18:30"}""" :: + """{"date": "28/01/2016 20:00"}""" :: Nil) + lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil) def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 2a89773cf534..ab9250045f5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.util.{AccumulatorContext, LongAccumulator} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -370,73 +371,75 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") { import testImplicits._ - - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", - SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { - withTempPath { dir => - val pathOne = s"${dir.getCanonicalPath}/table1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne) - val pathTwo = s"${dir.getCanonicalPath}/table2" - (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo) - - // If the "c = 1" filter gets pushed down, this query will throw an exception which - // Parquet emits. This is a Parquet issue (PARQUET-389). - val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") - checkAnswer( - df, - Row(1, "1", null)) - - // The fields "a" and "c" only exist in one Parquet file. - assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - - val pathThree = s"${dir.getCanonicalPath}/table3" - df.write.parquet(pathThree) - - // We will remove the temporary metadata when writing Parquet file. - val schema = spark.read.parquet(pathThree).schema - assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) - - val pathFour = s"${dir.getCanonicalPath}/table4" - val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") - dfStruct.select(struct("a").as("s")).write.parquet(pathFour) - - val pathFive = s"${dir.getCanonicalPath}/table5" - val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") - dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) - - // If the "s.c = 1" filter gets pushed down, this query will throw an exception which - // Parquet emits. - val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1") - .selectExpr("s") - checkAnswer(dfStruct3, Row(Row(null, 1))) - - // The fields "s.a" and "s.c" only exist in one Parquet file. - val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] - assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - - val pathSix = s"${dir.getCanonicalPath}/table6" - dfStruct3.write.parquet(pathSix) - - // We will remove the temporary metadata when writing Parquet file. - val forPathSix = spark.read.parquet(pathSix).schema - assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) - - // sanity test: make sure optional metadata field is not wrongly set. - val pathSeven = s"${dir.getCanonicalPath}/table7" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) - val pathEight = s"${dir.getCanonicalPath}/table8" - (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) - - val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") - checkAnswer( - df2, - Row(1, "1")) - - // The fields "a" and "b" exist in both two Parquet files. No metadata is set. - assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) - assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) + Seq("true", "false").map { vectorized => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + withTempPath { dir => + val pathOne = s"${dir.getCanonicalPath}/table1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne) + val pathTwo = s"${dir.getCanonicalPath}/table2" + (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo) + + // If the "c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. This is a Parquet issue (PARQUET-389). + val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") + checkAnswer( + df, + Row(1, "1", null)) + + // The fields "a" and "c" only exist in one Parquet file. + assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathThree = s"${dir.getCanonicalPath}/table3" + df.write.parquet(pathThree) + + // We will remove the temporary metadata when writing Parquet file. + val schema = spark.read.parquet(pathThree).schema + assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + val pathFour = s"${dir.getCanonicalPath}/table4" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(pathFour) + + val pathFive = s"${dir.getCanonicalPath}/table5" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) + + // If the "s.c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. + val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1") + .selectExpr("s") + checkAnswer(dfStruct3, Row(Row(null, 1))) + + // The fields "s.a" and "s.c" only exist in one Parquet file. + val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] + assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathSix = s"${dir.getCanonicalPath}/table6" + dfStruct3.write.parquet(pathSix) + + // We will remove the temporary metadata when writing Parquet file. + val forPathSix = spark.read.parquet(pathSix).schema + assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + // sanity test: make sure optional metadata field is not wrongly set. + val pathSeven = s"${dir.getCanonicalPath}/table7" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) + val pathEight = s"${dir.getCanonicalPath}/table8" + (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) + + val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + checkAnswer( + df2, + Row(1, "1")) + + // The fields "a" and "b" exist in both two Parquet files. No metadata is set. + assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) + assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) + } } } } @@ -559,4 +562,32 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(df.filter("_1 IS NOT NULL").count() === 4) } } + + test("Fiters should be pushed down for vectorized Parquet reader at row group level") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table" + (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path) + + Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) { + val accu = new LongAccumulator + accu.register(sparkContext, Some("numRowGroups")) + + val df = spark.read.parquet(path).filter("a < 100") + df.foreachPartition(_.foreach(v => accu.add(0))) + df.collect + + val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups") + assert(numRowGroups.isDefined) + assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value)) + AccumulatorContext.remove(accu.id) + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index fc9ce6bb3041..46ccfa53bd79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -38,11 +38,12 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -556,7 +557,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { checkAnswer( // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i32.parquet"), + readResourceParquetFile("test-data/dec-in-i32.parquet"), spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) } } @@ -567,7 +568,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { checkAnswer( // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i64.parquet"), + readResourceParquetFile("test-data/dec-in-i64.parquet"), spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) } } @@ -578,7 +579,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { checkAnswer( // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-fixed-len.parquet"), + readResourceParquetFile("test-data/dec-in-fixed-len.parquet"), spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } } @@ -677,6 +678,52 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } } + + test("VectorizedParquetRecordReader - partition column types") { + withTempPath { dir => + Seq(1).toDF().repartition(1).write.parquet(dir.getCanonicalPath) + + val dataTypes = + Seq(StringType, BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DateType, TimestampType) + + val constantValues = + Seq( + UTF8String.fromString("a string"), + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75D, + Decimal("1234.23456"), + DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2015-01-01")), + DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"))) + + dataTypes.zip(constantValues).foreach { case (dt, v) => + val schema = StructType(StructField("pcol", dt) :: Nil) + val vectorizedReader = new VectorizedParquetRecordReader + val partitionValues = new GenericMutableRow(Array(v)) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + + try { + vectorizedReader.initialize(file, null) + vectorizedReader.initBatch(schema, partitionValues) + vectorizedReader.nextKeyValue() + val row = vectorizedReader.getCurrentValue.asInstanceOf[InternalRow] + + // Use `GenericMutableRow` by explicitly copying rather than `ColumnarBatch` + // in order to use get(...) method which is not implemented in `ColumnarBatch`. + val actual = row.copy().get(1, dt) + val expected = v + assert(actual == expected) + } finally { + vectorizedReader.close() + } + } + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index 98333e58cada..fa88019298a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { test("unannotated array of primitive type") { - checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3))) } test("unannotated array of struct") { checkAnswer( - readResourceParquetFile("old-repeated-message.parquet"), + readResourceParquetFile("test-data/old-repeated-message.parquet"), Row( Seq( Row("First inner", null, null), @@ -35,14 +35,14 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh Row(null, null, "Third inner")))) checkAnswer( - readResourceParquetFile("proto-repeated-struct.parquet"), + readResourceParquetFile("test-data/proto-repeated-struct.parquet"), Row( Seq( Row("0 - 1", "0 - 2", "0 - 3"), Row("1 - 1", "1 - 2", "1 - 3")))) checkAnswer( - readResourceParquetFile("proto-struct-with-array-many.parquet"), + readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"), Seq( Row( Seq( @@ -60,13 +60,13 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("struct with unannotated array") { checkAnswer( - readResourceParquetFile("proto-struct-with-array.parquet"), + readResourceParquetFile("test-data/proto-struct-with-array.parquet"), Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) } test("unannotated array of struct with unannotated array") { checkAnswer( - readResourceParquetFile("nested-array-struct.parquet"), + readResourceParquetFile("test-data/nested-array-struct.parquet"), Seq( Row(2, Seq(Row(1, Seq(Row(3))))), Row(5, Seq(Row(4, Seq(Row(6))))), @@ -75,7 +75,7 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("unannotated array of string") { checkAnswer( - readResourceParquetFile("proto-repeated-string.parquet"), + readResourceParquetFile("test-data/proto-repeated-string.parquet"), Seq( Row(Seq("hello", "world")), Row(Seq("good", "bye")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index ff5706999a6d..4157a5b46dc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - private val parquetFilePath = - Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") + private val parquetFilePath = Thread.currentThread().getContextClassLoader.getResource( + "test-data/parquet-thrift-compat.snappy.parquet") test("Read Parquet file generated by parquet-thrift") { logInfo( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 71d3da915840..d11c2acb815d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -66,7 +66,7 @@ class TextSuite extends QueryTest with SharedSQLContext { test("reading partitioned data using read.textFile()") { val partitionedData = Thread.currentThread().getContextClassLoader - .getResource("text-partitioned").toString + .getResource("test-data/text-partitioned").toString val ds = spark.read.textFile(partitionedData) val data = ds.collect() @@ -76,7 +76,7 @@ class TextSuite extends QueryTest with SharedSQLContext { test("support for partitioned reading using read.text()") { val partitionedData = Thread.currentThread().getContextClassLoader - .getResource("text-partitioned").toString + .getResource("test-data/text-partitioned").toString val df = spark.read.text(partitionedData) val data = df.filter("year = '2015'").select("value").collect() @@ -155,7 +155,7 @@ class TextSuite extends QueryTest with SharedSQLContext { } private def testFile: String = { - Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString + Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } /** Verifies data and schema. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 40864c80ebc8..ede63fea9606 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import scala.util.Random + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.serializer.KryoSerializer @@ -152,6 +154,105 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { } } + test("LongToUnsafeRowMap with very wide range") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + { + // SPARK-16740 + val keys = Seq(0L, Long.MaxValue, Long.MaxValue) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + map.free() + } + + + { + // SPARK-16802 + val keys = Seq(Long.MaxValue, Long.MaxValue - 10) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + assert(map.getValue(Long.MinValue, row) eq null) + map.free() + } + } + + test("LongToUnsafeRowMap with random keys") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + val N = 1000000 + val rand = new Random + val keys = (0 to N).map(x => rand.nextLong()).toArray + + val map = new LongToUnsafeRowMap(taskMemoryManager, 10) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + map.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1) + map2.readExternal(in) + + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + val r = map2.get(k, row) + assert(r.hasNext) + var c = 0 + while (r.hasNext) { + val rr = r.next() + assert(rr.getLong(0) === k) + c += 1 + } + } + var i = 0 + while (i < N * 10) { + val k = rand.nextLong() + val r = map2.get(k, row) + if (r != null) { + assert(r.hasNext) + while (r.hasNext) { + assert(r.next().getLong(0) === k) + } + } + i += 1 + } + map.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala index 0a989d026ce1..8bd6b3c5cdc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala @@ -42,6 +42,20 @@ class ApproxQuantileSuite extends SparkFunSuite { summary.compress() } + /** + * Interleaves compression and insertions. + */ + private def buildCompressSummary( + data: Seq[Double], + epsi: Double, + threshold: Int): QuantileSummaries = { + var summary = new QuantileSummaries(threshold, epsi) + data.foreach { x => + summary = summary.insert(x).compress() + } + summary + } + private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { val approx = summary.query(quant) // The rank of the approximation. @@ -56,8 +70,8 @@ class ApproxQuantileSuite extends SparkFunSuite { for { (seq_name, data) <- Seq(increasing, decreasing, random) - epsi <- Seq(0.1, 0.0001) - compression <- Seq(1000, 10) + epsi <- Seq(0.1, 0.0001) // With a significant value and with full precision + compression <- Seq(1000, 10) // This interleaves n so that we test without and with compression } { test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") { @@ -77,6 +91,17 @@ class ApproxQuantileSuite extends SparkFunSuite { checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) } + + test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " + + s"(interleaved)") { + val s = buildCompressSummary(data, epsi, compression) + assert(s.count == data.size, s"Found count=${s.count} but data size=${data.size}") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } } // Tests for merging procedure diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 39fd1f0cd37b..41a8cc2400df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -25,13 +25,14 @@ import org.apache.spark.sql.test.SharedSQLContext class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { + import CompactibleFileStreamLog._ import FileStreamSinkLog._ test("getBatchIdFromFileName") { assert(1234L === getBatchIdFromFileName("1234")) assert(1234L === getBatchIdFromFileName("1234.compact")) intercept[NumberFormatException] { - FileStreamSinkLog.getBatchIdFromFileName("1234a") + getBatchIdFromFileName("1234a") } } @@ -83,22 +84,24 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { } test("compactLogs") { - val logs = Seq( - newFakeSinkFileStatus("/a/b/x", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/y", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.ADD_ACTION)) - assert(logs === compactLogs(logs)) + withFileStreamSinkLog { sinkLog => + val logs = Seq( + newFakeSinkFileStatus("/a/b/x", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/y", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.ADD_ACTION)) + assert(logs === sinkLog.compactLogs(logs)) - val logs2 = Seq( - newFakeSinkFileStatus("/a/b/m", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/n", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.DELETE_ACTION)) - assert(logs.dropRight(1) ++ logs2.dropRight(1) === compactLogs(logs ++ logs2)) + val logs2 = Seq( + newFakeSinkFileStatus("/a/b/m", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/n", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.DELETE_ACTION)) + assert(logs.dropRight(1) ++ logs2.dropRight(1) === sinkLog.compactLogs(logs ++ logs2)) + } } test("serialize") { withFileStreamSinkLog { sinkLog => - val logs = Seq( + val logs = Array( SinkFileStatus( path = "/a/b/x", size = 100L, @@ -125,21 +128,21 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { action = FileStreamSinkLog.ADD_ACTION)) // scalastyle:off - val expected = s"""${FileStreamSinkLog.VERSION} + val expected = s"""$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin // scalastyle:on assert(expected === new String(sinkLog.serialize(logs), UTF_8)) - assert(FileStreamSinkLog.VERSION === new String(sinkLog.serialize(Nil), UTF_8)) + assert(VERSION === new String(sinkLog.serialize(Array()), UTF_8)) } } test("deserialize") { withFileStreamSinkLog { sinkLog => // scalastyle:off - val logs = s"""${FileStreamSinkLog.VERSION} + val logs = s"""$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -173,7 +176,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { assert(expected === sinkLog.deserialize(logs.getBytes(UTF_8))) - assert(Nil === sinkLog.deserialize(FileStreamSinkLog.VERSION.getBytes(UTF_8))) + assert(Nil === sinkLog.deserialize(VERSION.getBytes(UTF_8))) } } @@ -196,7 +199,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { for (batchId <- 0 to 10) { sinkLog.add( batchId, - Seq(newFakeSinkFileStatus("/a/b/" + batchId, FileStreamSinkLog.ADD_ACTION))) + Array(newFakeSinkFileStatus("/a/b/" + batchId, FileStreamSinkLog.ADD_ACTION))) val expectedFiles = (0 to batchId).map { id => newFakeSinkFileStatus("/a/b/" + id, FileStreamSinkLog.ADD_ACTION) } @@ -230,17 +233,17 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { }.toSet } - sinkLog.add(0, Seq(newFakeSinkFileStatus("/a/b/0", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(0, Array(newFakeSinkFileStatus("/a/b/0", FileStreamSinkLog.ADD_ACTION))) assert(Set("0") === listBatchFiles()) - sinkLog.add(1, Seq(newFakeSinkFileStatus("/a/b/1", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(1, Array(newFakeSinkFileStatus("/a/b/1", FileStreamSinkLog.ADD_ACTION))) assert(Set("0", "1") === listBatchFiles()) - sinkLog.add(2, Seq(newFakeSinkFileStatus("/a/b/2", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(2, Array(newFakeSinkFileStatus("/a/b/2", FileStreamSinkLog.ADD_ACTION))) assert(Set("2.compact") === listBatchFiles()) - sinkLog.add(3, Seq(newFakeSinkFileStatus("/a/b/3", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(3, Array(newFakeSinkFileStatus("/a/b/3", FileStreamSinkLog.ADD_ACTION))) assert(Set("2.compact", "3") === listBatchFiles()) - sinkLog.add(4, Seq(newFakeSinkFileStatus("/a/b/4", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(4, Array(newFakeSinkFileStatus("/a/b/4", FileStreamSinkLog.ADD_ACTION))) assert(Set("2.compact", "3", "4") === listBatchFiles()) - sinkLog.add(5, Seq(newFakeSinkFileStatus("/a/b/5", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(5, Array(newFakeSinkFileStatus("/a/b/5", FileStreamSinkLog.ADD_ACTION))) assert(Set("5.compact") === listBatchFiles()) } } @@ -263,7 +266,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = { withTempDir { file => - val sinkLog = new FileStreamSinkLog(spark, file.getCanonicalPath) + val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, file.getCanonicalPath) f(sinkLog) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala new file mode 100644 index 000000000000..1793db0002af --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{File, FileNotFoundException} +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.ExistsThrowsExceptionFileSystem._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext { + + import FileStreamSource._ + + test("SeenFilesMap") { + val map = new SeenFilesMap(maxAgeMs = 10) + + map.add("a", 5) + assert(map.size == 1) + map.purge() + assert(map.size == 1) + + // Add a new entry and purge should be no-op, since the gap is exactly 10 ms. + map.add("b", 15) + assert(map.size == 2) + map.purge() + assert(map.size == 2) + + // Add a new entry that's more than 10 ms than the first entry. We should be able to purge now. + map.add("c", 16) + assert(map.size == 3) + map.purge() + assert(map.size == 2) + + // Override existing entry shouldn't change the size + map.add("c", 25) + assert(map.size == 2) + + // Not a new file because we have seen c before + assert(!map.isNewFile("c", 20)) + + // Not a new file because timestamp is too old + assert(!map.isNewFile("d", 5)) + + // Finally a new file: never seen and not too old + assert(map.isNewFile("e", 20)) + } + + test("SeenFilesMap should only consider a file old if it is earlier than last purge time") { + val map = new SeenFilesMap(maxAgeMs = 10) + + map.add("a", 20) + assert(map.size == 1) + + // Timestamp 5 should still considered a new file because purge time should be 0 + assert(map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) + + // Once purge, purge time should be 10 and then b would be a old file if it is less than 10. + map.purge() + assert(!map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) + } + + testWithUninterruptibleThread("do not recheck that files exist during getBatch") { + withTempDir { temp => + spark.conf.set( + s"fs.$scheme.impl", + classOf[ExistsThrowsExceptionFileSystem].getName) + // add the metadata entries as a pre-req + val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir + val metadataLog = + new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) + assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) + + val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), + dir.getAbsolutePath, Map.empty) + // this method should throw an exception if `fs.exists` is called during resolveRelation + newSource.getBatch(None, LongOffset(1)) + } + } +} + +/** Fake FileSystem to test whether the method `fs.exists` is called during + * `DataSource.resolveRelation`. + */ +class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem { + override def getUri: URI = { + URI.create(s"$scheme:///") + } + + override def exists(f: Path): Boolean = { + throw new IllegalArgumentException("Exists shouldn't have been called!") + } + + /** Simply return an empty file for now. */ + override def listStatus(file: Path): Array[FileStatus] = { + throw new FileNotFoundException("Folder was suddenly deleted but this should not make it fail!") + } +} + +object ExistsThrowsExceptionFileSystem { + val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index ab5a2d253b94..4259384f0bc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -46,14 +46,14 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { test("FileManager: FileContextManager") { withTempDir { temp => val path = new Path(temp.getAbsolutePath) - testManager(path, new FileContextManager(path, new Configuration)) + testFileManager(path, new FileContextManager(path, new Configuration)) } } test("FileManager: FileSystemManager") { withTempDir { temp => val path = new Path(temp.getAbsolutePath) - testManager(path, new FileSystemManager(path, new Configuration)) + testFileManager(path, new FileSystemManager(path, new Configuration)) } } @@ -103,6 +103,25 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } + testWithUninterruptibleThread("HDFSMetadataLog: purge") { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.add(1, "batch1")) + assert(metadataLog.add(2, "batch2")) + assert(metadataLog.get(0).isDefined) + assert(metadataLog.get(1).isDefined) + assert(metadataLog.get(2).isDefined) + assert(metadataLog.getLatest().get._1 == 2) + + metadataLog.purge(2) + assert(metadataLog.get(0).isEmpty) + assert(metadataLog.get(1).isEmpty) + assert(metadataLog.get(2).isDefined) + assert(metadataLog.getLatest().get._1 == 2) + } + } + testWithUninterruptibleThread("HDFSMetadataLog: restart") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) @@ -155,8 +174,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - - def testManager(basePath: Path, fm: FileManager): Unit = { + /** Basic test case for [[FileManager]] implementation. */ + private def testFileManager(basePath: Path, fm: FileManager): Unit = { // Mkdirs val dir = new Path(s"$basePath/dir/subdir/subsubdir") assert(!fm.exists(dir)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 100cc4daca87..e3943f31a48b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -802,8 +802,8 @@ class ColumnarBatchSuite extends SparkFunSuite { // Over-allocating beyond MAX_CAPACITY throws an exception column.appendBytes(10, 0.toByte) } - assert(ex.getMessage.contains(s"Cannot reserve more than ${column.MAX_CAPACITY} bytes in " + - s"the vectorized reader")) + assert(ex.getMessage.contains(s"Cannot reserve additional contiguous bytes in the " + + s"vectorized reader")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala new file mode 100644 index 000000000000..d826d3f54d92 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + +class ReduceAggregatorSuite extends SparkFunSuite { + + test("zero value") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) + assert(aggregator.zero == (false, null)) + } + + test("reduce, merge and finish") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) + + val firstReduce = aggregator.reduce(aggregator.zero, 1) + assert(firstReduce == (true, 1)) + + val secondReduce = aggregator.reduce(firstReduce, 2) + assert(secondReduce == (true, 3)) + + val thirdReduce = aggregator.reduce(secondReduce, 3) + assert(thirdReduce == (true, 6)) + + val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce) + assert(mergeWithZero1 == (true, 1)) + + val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero) + assert(mergeWithZero2 == (true, 3)) + + val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce) + assert(mergeTwoReduced == (true, 4)) + + assert(aggregator.finish(firstReduce)== 1) + assert(aggregator.finish(secondReduce) == 3) + assert(aggregator.finish(thirdReduce) == 6) + assert(aggregator.finish(mergeWithZero1) == 1) + assert(aggregator.finish(mergeWithZero2) == 3) + assert(aggregator.finish(mergeTwoReduced) == 4) + } + + test("requires at least one input row") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) + + intercept[IllegalStateException] { + aggregator.finish(aggregator.zero) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index d75df56dd608..e62ae38cd35a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -304,6 +304,17 @@ class CatalogSuite columnFields.foreach { f => assert(columnString.contains(f.toString)) } } + test("dropTempView should not un-cache and drop metastore table if a same-name table exists") { + withTable("same_name") { + spark.range(10).write.saveAsTable("same_name") + sql("CACHE TABLE same_name") + assert(spark.catalog.isCached("default.same_name")) + spark.catalog.dropTempView("same_name") + assert(spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + assert(spark.catalog.isCached("default.same_name")) + } + } + // TODO: add tests for the rest of them } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 2cd3f475b6c0..761bbe3576c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext} import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} @@ -28,7 +30,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(sparkContext) + val newContext = new SQLContext(SparkSession.builder().sparkContext(sparkContext).getOrCreate()) assert(newContext.getConf("spark.sql.testkey", "false") === "true") } @@ -214,7 +216,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { // to get the default value, always unset it spark.conf.unset(SQLConf.WAREHOUSE_PATH.key) assert(spark.sessionState.conf.warehousePath - === s"file:${System.getProperty("user.dir")}/spark-warehouse") + === new Path(s"${System.getProperty("user.dir")}/spark-warehouse").toString) } finally { sql(s"set ${SQLConf.WAREHOUSE_PATH}=$original") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 995b1200a229..1a6dba82b0e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -289,7 +289,7 @@ class JDBCSuite extends SparkFunSuite assert(names(2).equals("mary")) } - test("SELECT first field when fetchSize is two") { + test("SELECT first field when fetchsize is two") { val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) @@ -305,7 +305,7 @@ class JDBCSuite extends SparkFunSuite assert(ids(2) === 3) } - test("SELECT second field when fetchSize is two") { + test("SELECT second field when fetchsize is two") { val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) @@ -352,7 +352,7 @@ class JDBCSuite extends SparkFunSuite urlWithUserAndPass, "TEST.PEOPLE", new Properties()).collect().length === 3) } - test("Basic API with illegal FetchSize") { + test("Basic API with illegal fetchsize") { val properties = new Properties() properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "-1") val e = intercept[SparkException] { @@ -770,4 +770,12 @@ class JDBCSuite extends SparkFunSuite val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } + + test("SPARK-17673: Exchange reuse respects differences in output schema") { + val df = sql("SELECT * FROM inttypes WHERE a IS NOT NULL") + val df1 = df.groupBy("a").agg("c" -> "min") + val df2 = df.groupBy("a").agg("d" -> "min") + val res = df1.union(df2) + assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 5a7a9073fb3a..c2aedfff348a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -98,21 +98,21 @@ class DDLTestSuite extends DataSourceTest with SharedSQLContext { "describe ddlPeople", Seq( Row("intType", "int", "test comment test1"), - Row("stringType", "string", ""), - Row("dateType", "date", ""), - Row("timestampType", "timestamp", ""), - Row("doubleType", "double", ""), - Row("bigintType", "bigint", ""), - Row("tinyintType", "tinyint", ""), - Row("decimalType", "decimal(10,0)", ""), - Row("fixedDecimalType", "decimal(5,1)", ""), - Row("binaryType", "binary", ""), - Row("booleanType", "boolean", ""), - Row("smallIntType", "smallint", ""), - Row("floatType", "float", ""), - Row("mapType", "map", ""), - Row("arrayType", "array", ""), - Row("structType", "struct", "") + Row("stringType", "string", null), + Row("dateType", "date", null), + Row("timestampType", "timestamp", null), + Row("doubleType", "double", null), + Row("bigintType", "bigint", null), + Row("tinyintType", "tinyint", null), + Row("decimalType", "decimal(10,0)", null), + Row("fixedDecimalType", "decimal(5,1)", null), + Row("binaryType", "binary", null), + Row("booleanType", "boolean", null), + Row("smallIntType", "smallint", null), + Row("floatType", "float", null), + Row("mapType", "map", null), + Row("arrayType", "array", null), + Row("structType", "struct", null) )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 5ea1f3243369..76ffb949f129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -74,16 +74,16 @@ class ResolvedDataSourceSuite extends SparkFunSuite { val error1 = intercept[AnalysisException] { getProvidingClass("avro") } - assert(error1.getMessage.contains("spark-packages")) + assert(error1.getMessage.contains("Failed to find data source: avro.")) val error2 = intercept[AnalysisException] { getProvidingClass("com.databricks.spark.avro") } - assert(error2.getMessage.contains("spark-packages")) + assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro.")) val error3 = intercept[ClassNotFoundException] { getProvidingClass("asfdwefasdfasdf") } - assert(error3.getMessage.contains("spark-packages")) + assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 47260a23c7ee..55c95ae285c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.streaming import java.io.File -import java.util.UUID + +import org.scalatest.PrivateMethodTester +import org.scalatest.time.SpanSugar._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util._ @@ -28,7 +30,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FileStreamSourceTest extends StreamTest with SharedSQLContext { +class FileStreamSourceTest extends StreamTest with SharedSQLContext with PrivateMethodTester { import testImplicits._ @@ -104,12 +106,13 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { def createFileStream( format: String, path: String, - schema: Option[StructType] = None): DataFrame = { + schema: Option[StructType] = None, + options: Map[String, String] = Map.empty): DataFrame = { val reader = if (schema.isDefined) { - spark.readStream.format(format).schema(schema.get) + spark.readStream.format(format).schema(schema.get).options(options) } else { - spark.readStream.format(format) + spark.readStream.format(format).options(options) } reader.load(path) } @@ -141,6 +144,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest { import testImplicits._ + override val streamingTimeout = 20.seconds + /** Use `format` and `path` to create FileStreamSource via DataFrameReader */ private def createFileStreamSource( format: String, @@ -331,6 +336,42 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-17165 should not track the list of seen files indefinitely") { + // This test works by: + // 1. Create a file + // 2. Get it processed + // 3. Sleeps for a very short amount of time (larger than maxFileAge + // 4. Add another file (at this point the original file should have been purged + // 5. Test the size of the seenFiles internal data structure + + // Note that if we change maxFileAge to a very large number, the last step should fail. + withTempDirs { case (src, tmp) => + val textStream: DataFrame = + createFileStream("text", src.getCanonicalPath, options = Map("maxFileAge" -> "5ms")) + + testStream(textStream)( + AddTextFileData("a\nb", src, tmp), + CheckAnswer("a", "b"), + + // SLeeps longer than 5ms (maxFileAge) + // Unfortunately since a lot of file system does not have modification time granularity + // finer grained than 1 sec, we need to use 1 sec here. + AssertOnQuery { _ => Thread.sleep(1000); true }, + + AddTextFileData("c\nd", src, tmp), + CheckAnswer("a", "b", "c", "d"), + + AssertOnQuery("seen files should contain only one entry") { streamExecution => + val source = streamExecution.logicalPlan.collect { case e: StreamingExecutionRelation => + e.source.asInstanceOf[FileStreamSource] + }.head + assert(source.seenFiles.size == 1) + true + } + ) + } + } + // =============== JSON file stream tests ================ test("read from json files") { @@ -727,6 +768,137 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } } + + test("SPARK-17372 - write file names to WAL as Array[String]") { + // Note: If this test takes longer than the timeout, then its likely that this is actually + // running a Spark job with 10000 tasks. This test tries to avoid that by + // 1. Setting the threshold for parallel file listing to very high + // 2. Using a query that should use constant folding to eliminate reading of the files + + val numFiles = 10000 + + // This is to avoid running a spark job to list of files in parallel + // by the ListingFileCatalog. + spark.sessionState.conf.setConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD, numFiles * 2) + + withTempDirs { case (root, tmp) => + val src = new File(root, "a=1") + src.mkdirs() + + (1 to numFiles).map { _.toString }.foreach { i => + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + stringToFile(finalFile, i) + } + assert(src.listFiles().size === numFiles) + + val files = spark.readStream.text(root.getCanonicalPath).as[String] + + // Note this query will use constant folding to eliminate the file scan. + // This is to avoid actually running a Spark job with 10000 tasks + val df = files.filter("1 == 0").groupBy().count() + + testStream(df, InternalOutputModes.Complete)( + AddTextFileData("0", src, tmp), + CheckAnswer(0) + ) + } + } + + test("compacat metadata log") { + val _sources = PrivateMethod[Seq[Source]]('sources) + val _metadataLog = PrivateMethod[FileStreamSourceLog]('metadataLog) + + def verify(execution: StreamExecution) + (batchId: Long, expectedBatches: Int): Boolean = { + import CompactibleFileStreamLog._ + + val fileSource = (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] + val metadataLog = fileSource invokePrivate _metadataLog() + + if (isCompactionBatch(batchId, 2)) { + val path = metadataLog.batchIdToPath(batchId) + + // Assert path name should be ended with compact suffix. + assert(path.getName.endsWith(COMPACT_FILE_SUFFIX)) + + // Compacted batch should include all entries from start. + val entries = metadataLog.get(batchId) + assert(entries.isDefined) + assert(entries.get.length === metadataLog.allFiles().length) + assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === entries.get.length) + } + + assert(metadataLog.allFiles().sortBy(_.batchId) === + metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId)) + + metadataLog.get(None, Some(batchId)).flatMap(_._2).length === expectedBatches + } + + withTempDirs { case (src, tmp) => + withSQLConf( + SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key -> "2" + ) { + val fileStream = createFileStream("text", src.getCanonicalPath) + val filtered = fileStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + AssertOnQuery(verify(_)(0L, 1)), + AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AssertOnQuery(verify(_)(1L, 2)), + AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9"), + AssertOnQuery(verify(_)(2L, 3)), + StopStream, + StartStream(), + AssertOnQuery(verify(_)(2L, 3)), + AddTextFileData("drop10\nkeep11", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11"), + AssertOnQuery(verify(_)(3L, 4)), + AddTextFileData("drop12\nkeep13", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11", "keep13"), + AssertOnQuery(verify(_)(4L, 5)) + ) + } + } + } + + test("get arbitrary batch from FileStreamSource") { + withTempDirs { case (src, tmp) => + withSQLConf( + SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key -> "2", + // Force deleting the old logs + SQLConf.FILE_SOURCE_LOG_CLEANUP_DELAY.key -> "1" + ) { + val fileStream = createFileStream("text", src.getCanonicalPath) + val filtered = fileStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("keep1", src, tmp), + CheckAnswer("keep1"), + AddTextFileData("keep2", src, tmp), + CheckAnswer("keep1", "keep2"), + AddTextFileData("keep3", src, tmp), + CheckAnswer("keep1", "keep2", "keep3"), + AssertOnQuery("check getBatch") { execution: StreamExecution => + val _sources = PrivateMethod[Seq[Source]]('sources) + val fileSource = + (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] + assert(fileSource.getBatch(None, LongOffset(2)).as[String].collect() === + List("keep1", "keep2", "keep3")) + assert(fileSource.getBatch(Some(LongOffset(0)), LongOffset(2)).as[String].collect() === + List("keep2", "keep3")) + assert(fileSource.getBatch(Some(LongOffset(1)), LongOffset(2)).as[String].collect() === + List("keep3")) + true + } + ) + } + } + } } class FileStreamSourceStressTestSuite extends FileStreamSourceTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af2b58116b2a..6c5b170d9c7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -188,8 +188,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { new AssertOnQuery(condition, message) } - def apply(message: String)(condition: StreamExecution => Unit): AssertOnQuery = { - new AssertOnQuery(s => { condition(s); true }, message) + def apply(message: String)(condition: StreamExecution => Boolean): AssertOnQuery = { + new AssertOnQuery(condition, message) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 7f4d28cf0598..831543a47420 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -66,6 +66,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // No progress events or termination events assert(listener.progressStatuses.isEmpty) assert(listener.terminationStatus === null) + true }, AddDataMemory(input, Seq(1, 2, 3)), CheckAnswer(1, 2, 3), @@ -84,6 +85,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // No termination events assert(listener.terminationStatus === null) } + true }, StopStream, AssertOnQuery("Incorrect query status in onQueryTerminated") { query => @@ -94,10 +96,10 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(status.id === query.id) assert(status.sourceStatuses(0).offsetDesc === Some(LongOffset(0).toString)) assert(status.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString) - assert(listener.terminationStackTrace.isEmpty) assert(listener.terminationException === None) } listener.checkAsyncErrors() + true } ) } @@ -147,7 +149,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } - test("exception should be reported in QueryTerminated") { + testQuietly("exception should be reported in QueryTerminated") { val listener = new QueryStatusCollector withListenerAdded(listener) { val input = MemoryStream[Int] @@ -159,8 +161,11 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { spark.sparkContext.listenerBus.waitUntilEmpty(10000) assert(listener.terminationStatus !== null) assert(listener.terminationException.isDefined) + // Make sure that the exception message reported through listener + // contains the actual exception and relevant stack trace + assert(!listener.terminationException.get.contains("StreamingQueryException")) assert(listener.terminationException.get.contains("java.lang.ArithmeticException")) - assert(listener.terminationStackTrace.nonEmpty) + assert(listener.terminationException.get.contains("StreamingQueryListenerSuite")) } ) } @@ -205,8 +210,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val exception = new RuntimeException("exception") val queryQueryTerminated = new StreamingQueryListener.QueryTerminated( queryTerminatedInfo, - Some(exception.getMessage), - exception.getStackTrace) + Some(exception.getMessage)) val json = JsonProtocol.sparkEventToJson(queryQueryTerminated) val newQueryTerminated = JsonProtocol.sparkEventFromJson(json) @@ -262,7 +266,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { @volatile var startStatus: StreamingQueryInfo = null @volatile var terminationStatus: StreamingQueryInfo = null @volatile var terminationException: Option[String] = null - @volatile var terminationStackTrace: Seq[StackTraceElement] = null val progressStatuses = new ConcurrentLinkedQueue[StreamingQueryInfo] @@ -296,7 +299,6 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(startStatus != null, "onQueryTerminated called before onQueryStarted") terminationStatus = queryTerminated.queryInfo terminationException = queryTerminated.exception - terminationStackTrace = queryTerminated.stackTrace } asyncTestWaiter.dismiss() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 9d58315c2003..88f1f188ab2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -125,6 +125,30 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter { ) } + testQuietly("StreamExecution metadata garbage collection") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(6 / _) + + // Run 3 batches, and then assert that only 1 metadata file is left at the end + // since the first 2 should have been purged. + testStream(mapped)( + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + AddData(inputData, 1, 2), + CheckAnswer(6, 3, 6, 3), + AddData(inputData, 4, 6), + CheckAnswer(6, 3, 6, 3, 1, 1), + + AssertOnQuery("metadata log should contain only one file") { q => + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) + val toTest = logFileNames.filter(! _.endsWith(".crc")) // Workaround for SPARK-17475 + assert(toTest.size == 1 && toTest.head == "2") + true + } + ) + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 27a0a2a776c3..bba265e9c934 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -22,6 +22,7 @@ import java.io.File import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.util.Utils @@ -424,6 +425,14 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be spark.range(10).write.orc(dir) } + test("SPARK-17230: write out results of decimal calculation") { + val df = spark.range(99, 101) + .selectExpr("id", "cast(id as long) * cast('1.0' as decimal(38, 18)) as num") + df.write.mode(SaveMode.Overwrite).parquet(dir) + val df2 = spark.read.parquet(dir) + checkAnswer(df2, df) + } + private def testRead( df: => DataFrame, expectedResult: Seq[String], @@ -431,4 +440,79 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be checkAnswer(df, spark.createDataset(expectedResult).toDF()) assert(df.schema === expectedSchema) } + + test("saveAsTable with mode Append should not fail if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Append should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + sql("CREATE TABLE same_name(id LONG) USING parquet") + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } + + test("saveAsTable with mode ErrorIfExists should not fail if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.ErrorIfExists).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Overwrite should not drop the temp view if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Overwrite).saveAsTable("same_name") + assert(spark.sessionState.catalog.getTempView("same_name").isDefined) + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Overwrite should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + sql("CREATE TABLE same_name(id LONG) USING parquet") + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Overwrite).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } + + test("saveAsTable with mode Ignore should create the table if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Ignore).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } } diff --git a/sql/hive-thriftserver/build.gradle b/sql/hive-thriftserver/build.gradle new file mode 100644 index 000000000000..6dd72cdbd08a --- /dev/null +++ b/sql/hive-thriftserver/build.gradle @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Hive Thrift Server' + +dependencies { + compile project(subprojectBase + 'snappy-spark-hive_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + compile(group: 'org.spark-project.hive', name: 'hive-cli', version: hiveVersion) { + exclude(group: 'org.spark-project.hive', module: 'hive-common') + exclude(group: 'org.spark-project.hive', module: 'hive-exec') + exclude(group: 'org.spark-project.hive', module: 'hive-jdbc') + exclude(group: 'org.spark-project.hive', module: 'hive-metastore') + exclude(group: 'org.spark-project.hive', module: 'hive-serde') + exclude(group: 'org.spark-project.hive', module: 'hive-service') + exclude(group: 'org.spark-project.hive', module: 'hive-shims') + exclude(group: 'org.apache.thrift', module: 'libthrift') + exclude(group: 'org.slf4j', module: 'slf4j-api') + exclude(group: 'org.slf4j', module: 'slf4j-log4j12') + exclude(group: 'log4j', module: 'log4j') + exclude(group: 'commons-logging', module: 'commons-logging') + } + compile(group: 'org.spark-project.hive', name: 'hive-beeline', version: hiveVersion) { + exclude(group: 'org.spark-project.hive', module: 'hive-common') + exclude(group: 'org.spark-project.hive', module: 'hive-exec') + exclude(group: 'org.spark-project.hive', module: 'hive-jdbc') + exclude(group: 'org.spark-project.hive', module: 'hive-metastore') + exclude(group: 'org.spark-project.hive', module: 'hive-service') + exclude(group: 'org.spark-project.hive', module: 'hive-shims') + exclude(group: 'org.apache.thrift', module: 'libthrift') + exclude(group: 'org.slf4j', module: 'slf4j-api') + exclude(group: 'org.slf4j', module: 'slf4j-log4j12') + exclude(group: 'log4j', module: 'log4j') + exclude(group: 'commons-logging', module: 'commons-logging') + } + compile(group: 'org.spark-project.hive', name: 'hive-jdbc', version: hiveVersion) { + exclude(group: 'org.spark-project.hive', module: 'hive-common') + exclude(group: 'org.spark-project.hive', module: 'hive-metastore') + exclude(group: 'org.spark-project.hive', module: 'hive-serde') + exclude(group: 'org.spark-project.hive', module: 'hive-service') + exclude(group: 'org.spark-project.hive', module: 'hive-shims') + exclude(group: 'org.spark-project.hive', module: 'httpclient') + exclude(group: 'org.apache.curator', module: 'curator-framework') + exclude(group: 'org.apache.thrift', module: 'libthrift') + exclude(group: 'org.apache.thrift', module: 'libfb303') + exclude(group: 'org.apache.zookeeper', module: 'zookeeper') + exclude(group: 'org.slf4j', module: 'slf4j-api') + exclude(group: 'org.slf4j', module: 'slf4j-log4j12') + exclude(group: 'log4j', module: 'log4j') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'org.codehaus.groovy', module: 'groovy-all') + } + compile(group: 'net.sf.jpam', name: 'jpam', version: jpamVersion) { + exclude(group: 'javax.servlet', module: 'servlet-api') + } + + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile project(path: subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile(group: 'org.seleniumhq.selenium', name: 'selenium-java', version: seleniumVersion) { + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'io.netty', module: 'netty') + } + testCompile(group: 'org.seleniumhq.selenium', name: 'selenium-htmlunit-driver', version: seleniumVersion) { + exclude(group: 'com.google.guava', module: 'guava') + } +} + +// add generated sources +sourceSets.main.scala.srcDir 'src/gen/java' diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 672425c86ecb..c47d5f028516 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e8bcdd76efd7..b2717ec54e69 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -51,6 +51,7 @@ private[hive] class SparkExecuteStatementOperation( private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ + private var iterHeader: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ private var statementId: String = _ @@ -110,6 +111,14 @@ private[hive] class SparkExecuteStatementOperation( assertState(OperationState.FINISHED) setHasResultSet(true) val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + + // Reset iter to header when fetching start from first row + if (order.equals(FetchOrientation.FETCH_FIRST)) { + val (ita, itb) = iterHeader.duplicate + iter = ita + iterHeader = itb + } + if (!iter.hasNext) { resultRowSet } else { @@ -228,6 +237,9 @@ private[hive] class SparkExecuteStatementOperation( result.collect().iterator } } + val (itra, itrb) = iter.duplicate + iterHeader = itra + iter = itrb dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray } catch { case e: HiveSQLException => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index d3cec11bd756..6a383592cff3 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.hive.thriftserver @@ -42,6 +60,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() val scratchDirPath = Utils.createTempDir() + val sparkHome = new File(sys.props.getOrElse("spark.test.home", + fail("spark.test.home is not set!"))) override def beforeAll(): Unit = { super.beforeAll() @@ -82,7 +102,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val queriesString = queries.map(_ + "\n").mkString val command = { - val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) + val cliScript = "./bin/spark-sql".split("/").mkString(File.separator) val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local @@ -123,7 +143,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } } - val process = new ProcessBuilder(command: _*).start() + val process = new ProcessBuilder(command: _*).directory(sparkHome).start() val stdinWriter = new OutputStreamWriter(process.getOutputStream, StandardCharsets.UTF_8) stdinWriter.write(queriesString) @@ -200,8 +220,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } test("Commands using SerDe provided in --jars") { - val jarFile = - "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + val jar = "hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + val jarFile = sys.props.get("spark.project.home").map( + _ + "/sql/" + jar).getOrElse("../" + jar) .split("/") .mkString(File.separator) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index e388c2a082f1..fea18526044a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.hive.thriftserver @@ -36,6 +54,8 @@ import org.apache.hive.service.auth.PlainSaslHelper import org.apache.hive.service.cli.GetInfoType import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.hive.service.cli.FetchOrientation +import org.apache.hive.service.cli.FetchType import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll @@ -91,6 +111,52 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("SPARK-16563 ThriftCLIService FetchResults repeat fetching result") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_16563", + "CREATE TABLE test_16563(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_16563") + + queries.foreach(statement.execute) + val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String] + val operationHandle = client.executeStatement( + sessionHandle, + "SELECT * FROM test_16563", + confOverlay) + + // Fetch result first time + assertResult(5, "Fetching result first time from next row") { + + val rows_next = client.fetchResults( + operationHandle, + FetchOrientation.FETCH_NEXT, + 1000, + FetchType.QUERY_OUTPUT) + + rows_next.numRows() + } + + // Fetch result second time from first row + assertResult(5, "Repeat fetching result from first row") { + + val rows_first = client.fetchResults( + operationHandle, + FetchOrientation.FETCH_FIRST, + 1000, + FetchType.QUERY_OUTPUT) + + rows_first.numRows() + } + statement.executeQuery("DROP TABLE IF EXISTS test_16563") + } + } + } + test("JDBC query execution") { withJdbcStatement { statement => val queries = Seq( @@ -433,8 +499,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withMultipleConnectionJdbcStatement( { statement => - val jarFile = - "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + val jar = "hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + val jarFile = sys.props.get("spark.project.home").map( + _ + "/sql/" + jar).getOrElse("../" + jar) .split("/") .mkString(File.separator) @@ -689,8 +756,11 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to " - protected val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) - protected val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + protected val startScript = "./sbin/start-thriftserver.sh".split("/").mkString(File.separator) + protected val stopScript = "./sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + protected val sparkHome = sys.props.getOrElse("spark.test.home", + fail("spark.test.home is not set!")) private var listeningPort: Int = _ protected def serverPort: Int = listeningPort @@ -788,6 +858,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl logPath = { val lines = Utils.executeAndGetOutput( command = command, + workingDir = new File(sparkHome), extraEnvironment = Map( // Disables SPARK_TESTING to exclude log4j.properties in test directories. "SPARK_TESTING" -> "0", @@ -841,6 +912,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. Utils.executeAndGetOutput( command = Seq(stopScript), + workingDir = new File(sparkHome), extraEnvironment = Map("SPARK_PID_DIR" -> pidDir.getCanonicalPath)) Thread.sleep(3.seconds.toMillis) diff --git a/sql/hive/build.gradle b/sql/hive/build.gradle new file mode 100644 index 000000000000..25f6d76d4d7e --- /dev/null +++ b/sql/hive/build.gradle @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Hive' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'com.twitter', name: 'parquet-hadoop-bundle', version: hiveParquetVersion + compile group: 'org.apache.derby', name: 'derby', version: derbyVersion + compile(group: 'org.spark-project.hive', name: 'hive-exec', version: hiveVersion) { + exclude(group: 'org.spark-project.hive', module: 'hive-metastore') + exclude(group: 'org.spark-project.hive', module: 'hive-shims') + exclude(group: 'org.spark-project.hive', module: 'hive-ant') + exclude(group: 'org.spark-project.hive', module: 'spark-client') + exclude(group: 'org.apache.ant', module: 'ant') + exclude(group: 'com.esotericsoftware.kryo', module: 'kryo') + exclude(group: 'commons-codec', module: 'commons-codec') + exclude(group: 'commons-httpclient', module: 'commons-httpclient') + exclude(group: 'org.apache.avro', module: 'avro-mapred') + exclude(group: 'org.apache.calcite', module: 'calcite-core') + exclude(group: 'org.apache.curator', module: 'apache-curator') + exclude(group: 'org.apache.curator', module: 'curator-client') + exclude(group: 'org.apache.curator', module: 'curator-framework') + exclude(group: 'org.apache.thrift', module: 'libthrift') + exclude(group: 'org.apache.thrift', module: 'libfb303') + exclude(group: 'org.apache.zookeeper', module: 'zookeeper') + exclude(group: 'org.slf4j', module: 'slf4j-api') + exclude(group: 'org.slf4j', module: 'slf4j-log4j12') + exclude(group: 'log4j', module: 'log4j') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'org.codehaus.groovy', module: 'groovy-all') + exclude(group: 'jline', module: 'jline') + } + compile(group: 'org.spark-project.hive', name: 'hive-metastore', version: hiveVersion) { + exclude(group: 'org.spark-project.hive', module: 'hive-serde') + exclude(group: 'org.spark-project.hive', module: 'hive-shims') + exclude(group: 'org.apache.thrift', module: 'libfb303') + exclude(group: 'org.apache.thrift', module: 'libthrift') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'org.slf4j', module: 'slf4j-api') + exclude(group: 'org.slf4j', module: 'slf4j-log4j12') + exclude(group: 'org.apache.derby', module: 'derby') + } + + compile group: 'org.apache.avro', name: 'avro', version: avroVersion + compile(group: 'org.apache.avro', name: 'avro-ipc', version: avroVersion) { + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'org.mortbay.jetty', module: 'jetty') + exclude(group: 'org.mortbay.jetty', module: 'jetty-util') + exclude(group: 'org.mortbay.jetty', module: 'servlet-api') + exclude(group: 'org.apache.velocity', module: 'velocity') + } + compile(group: 'org.apache.avro', name: 'avro-mapred', version: avroVersion, classifier: 'hadoop2') { + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'org.mortbay.jetty', module: 'jetty') + exclude(group: 'org.mortbay.jetty', module: 'jetty-util') + exclude(group: 'org.mortbay.jetty', module: 'servlet-api') + exclude(group: 'org.apache.velocity', module: 'velocity') + exclude(group: 'org.apache.avro', module: 'avro-ipc') + } + compile group: 'commons-httpclient', name: 'commons-httpclient', version: '3.1' + compile(group: 'org.apache.calcite', name: 'calcite-avatica', version: '1.2.0-incubating') { + exclude(group: 'com.fasterxml.jackson.core', module: 'jackson-annotations') + exclude(group: 'com.fasterxml.jackson.core', module: 'jackson-core') + exclude(group: 'com.fasterxml.jackson.core', module: 'jackson-databind') + } + compile(group: 'org.apache.calcite', name: 'calcite-core', version: '1.2.0-incubating') { + exclude(group: 'com.fasterxml.jackson.core', module: 'jackson-annotations') + exclude(group: 'com.fasterxml.jackson.core', module: 'jackson-core') + exclude(group: 'com.fasterxml.jackson.core', module: 'jackson-databind') + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'com.google.code.findbugs', module: 'jsr305') + exclude(group: 'org.codehaus.janino', module: 'janino') + exclude(group: 'org.hsqldb', module: 'hsqldb') + exclude(group: 'org.pentaho', module: 'pentaho-aggdesigner-algorithm') + } + compile group: 'org.apache.httpcomponents', name: 'httpclient', version: httpClientVersion + compile group: 'org.codehaus.jackson', name: 'jackson-mapper-asl', version: '1.9.13' + compile group: 'commons-codec', name: 'commons-codec', version: commonsCodecVersion + compile group: 'joda-time', name: 'joda-time', version: '2.9.4' + compile group: 'org.jodd', name: 'jodd-core', version: '3.5.2' + compile group: 'com.google.code.findbugs', name: 'jsr305', version: jsr305Version + compile group: 'org.datanucleus', name: 'datanucleus-core', version: '3.2.10' + compile(group: 'org.apache.thrift', name: 'libthrift', version: thriftVersion) { + exclude(group: 'org.slf4j', module: 'slf4j-api') + } + compile(group: 'org.apache.thrift', name: 'libfb303', version: thriftVersion) { + exclude(group: 'org.slf4j', module: 'slf4j-api') + } + + testCompile group: 'org.apache.avro', name: 'avro-ipc', version: avroVersion, classifier: 'tests' + + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile project(path: subprojectBase + 'snappy-spark-sql_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile project(path: subprojectBase + 'snappy-spark-catalyst_' + scalaBinaryVersion, configuration: 'testOutput') +} + +// fix scala+java test ordering +sourceSets.test.scala.srcDirs 'src/test/java', 'compatibility/src/test/scala' +sourceSets.test.java.srcDirs = [] diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 13d18fdec0e9..a54d23487625 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -979,8 +979,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_PI", "udf_acos", "udf_add", - "udf_array", - "udf_array_contains", + // "udf_array", -- done in array.sql + // "udf_array_contains", -- done in array.sql "udf_ascii", "udf_asin", "udf_atan", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index add4375364b1..2388d5f6ef1d 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b8bc9ab900ad..88cf06fce6c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -339,31 +339,39 @@ private[spark] class HiveExternalCatalog(client: HiveClient, hadoopConf: Configu override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { + requireDbExists(db) // Hive's metastore is case insensitive. However, Hive's createFunction does // not normalize the function name (unlike the getFunction part). So, // we are normalizing the function name. val functionName = funcDefinition.identifier.funcName.toLowerCase + requireFunctionNotExists(db, functionName) val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } override def dropFunction(db: String, name: String): Unit = withClient { + requireFunctionExists(db, name) client.dropFunction(db, name) } override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + requireFunctionExists(db, oldName) + requireFunctionNotExists(db, newName) client.renameFunction(db, oldName, newName) } override def getFunction(db: String, funcName: String): CatalogFunction = withClient { + requireFunctionExists(db, funcName) client.getFunction(db, funcName) } override def functionExists(db: String, funcName: String): Boolean = withClient { + requireDbExists(db) client.functionExists(db, funcName) } override def listFunctions(db: String, pattern: String): Seq[String] = withClient { + requireDbExists(db) client.listFunctions(db, pattern) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 585befe37825..9d56aec4a963 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -557,7 +557,8 @@ private[hive] trait HiveInspectors { // 1. create the pojo (most likely) object val result = x.create() var i = 0 - while (i < fieldRefs.size) { + val size = fieldRefs.size + while (i < size) { // 2. set the property for the pojo val tpe = structType(i).dataType x.setStructFieldData( @@ -574,7 +575,8 @@ private[hive] trait HiveInspectors { val row = a.asInstanceOf[InternalRow] val result = new java.util.ArrayList[AnyRef](fieldRefs.size) var i = 0 - while (i < fieldRefs.size) { + val size = fieldRefs.size + while (i < size) { val tpe = structType(i).dataType result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 @@ -610,7 +612,8 @@ private[hive] trait HiveInspectors { cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - while (i < inspectors.length) { + val length = inspectors.length + while (i < length) { cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) i += 1 } @@ -623,7 +626,8 @@ private[hive] trait HiveInspectors { cache: Array[AnyRef], dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 - while (i < inspectors.length) { + val length = inspectors.length + while (i < length) { cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) i += 1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 789f94aff303..bafb42277e33 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -300,13 +300,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } val relation = HadoopFsRelation( - sparkSession = sparkSession, location = fileCatalog, partitionSchema = partitionSchema, dataSchema = inferredSchema, bucketSpec = bucketSpec, fileFormat = defaultSource, - options = options) + options = options)(sparkSession = sparkSession) val created = LogicalRelation( relation, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index c59ac3dcafea..1684e8debe3e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -230,10 +230,8 @@ private[sql] class HiveSessionCatalog( // List of functions we are explicitly not supporting are: // compute_stats, context_ngrams, create_union, // current_user, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field, - // in_file, index, java_method, - // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming, - // parse_url_tuple, posexplode, reflect2, - // str_to_map, windowingtablefunction. + // in_file, index, matchpath, ngrams, noop, noopstreaming, noopwithmap, + // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction. private val hiveFunctions = Seq( "hash", "histogram_numeric", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index bdec611453b2..487cfd087aa0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -280,7 +280,7 @@ private[spark] object HiveUtils extends Logging { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: $hiveExecutionVersion != Metastore: $hiveMetastoreVersion. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") } @@ -394,6 +394,13 @@ private[spark] object HiveUtils extends Logging { // hive.metastore.uris is not set. propMap.put(ConfVars.METASTOREURIS.varname, "") + // The execution client will generate garbage events, therefore the listeners that are generated + // for the execution clients are useless. In order to not output garbage, we don't generate + // these listeners. + propMap.put(ConfVars.METASTORE_PRE_EVENT_LISTENERS.varname, "") + propMap.put(ConfVars.METASTORE_EVENT_LISTENERS.varname, "") + propMap.put(ConfVars.METASTORE_END_FUNCTION_LISTENERS.varname, "") + propMap.toMap } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala index 58bca2059cac..10d71bef94d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala @@ -61,7 +61,7 @@ private[hive] case class MetastoreRelation( Objects.hashCode(databaseName, tableName, alias, output) } - override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: sparkSession :: Nil + override protected def otherCopyArgs: Seq[AnyRef] = catalogTable :: client :: sparkSession :: Nil private def toHiveColumn(c: CatalogColumn): FieldSchema = { new FieldSchema(c.name, c.dataType, c.comment.orNull) @@ -82,7 +82,6 @@ private[hive] case class MetastoreRelation( tTable.setTableType(catalogTable.tableType match { case CatalogTableType.EXTERNAL => HiveTableType.EXTERNAL_TABLE.toString case CatalogTableType.MANAGED => HiveTableType.MANAGED_TABLE.toString - case CatalogTableType.INDEX => HiveTableType.INDEX_TABLE.toString case CatalogTableType.VIEW => HiveTableType.VIRTUAL_VIEW.toString }) @@ -164,7 +163,13 @@ private[hive] case class MetastoreRelation( val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tPartition.setSd(sd) - sd.setCols(catalogTable.schema.map(toHiveColumn).asJava) + + // Note: In Hive the schema and partition columns must be disjoint sets + val schema = catalogTable.schema.map(toHiveColumn).filter { c => + !catalogTable.partitionColumnNames.contains(c.getName) + } + sd.setCols(schema.asJava) + p.storage.locationUri.foreach(sd.setLocation) p.storage.inputFormat.foreach(sd.setInputFormat) p.storage.outputFormat.foreach(sd.setOutputFormat) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index e4cb33b28520..a768b9d6d71b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -426,7 +426,8 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { iterator.map { value => val raw = converter.convert(rawDeser.deserialize(value)) var i = 0 - while (i < fieldRefs.length) { + val length = fieldRefs.length + while (i < length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6cdf3ef54500..7db51d4b493a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -139,14 +139,32 @@ private[hive] class HiveClientImpl( // so we should keep `conf` and reuse the existing instance of `CliSessionState`. originalState } else { - val hiveConf = new HiveConf(hadoopConf, classOf[SessionState]) + val hiveConf = new HiveConf(classOf[SessionState]) + // 1: we set all confs in the hadoopConf to this hiveConf. + // This hadoopConf contains user settings in Hadoop's core-site.xml file + // and Hive's hive-site.xml file. Note, we load hive-site.xml file manually in + // SharedState and put settings in this hadoopConf instead of relying on HiveConf + // to load user settings. Otherwise, HiveConf's initialize method will override + // settings in the hadoopConf. This issue only shows up when spark.sql.hive.metastore.jars + // is not set to builtin. When spark.sql.hive.metastore.jars is builtin, the classpath + // has hive-site.xml. So, HiveConf will use that to override its default values. + hadoopConf.iterator().asScala.foreach { entry => + val key = entry.getKey + val value = entry.getValue + if (key.toLowerCase.contains("password")) { + logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=xxx") + } else { + logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=$value") + } + hiveConf.set(key, value) + } // HiveConf is a Hadoop Configuration, which has a field of classLoader and // the initial value will be the current thread's context class loader // (i.e. initClassLoader at here). // We call initialConf.setClassLoader(initClassLoader) at here to make // this action explicit. hiveConf.setClassLoader(initClassLoader) - // First, we set all spark confs to this hiveConf. + // 2: we set all spark confs to this hiveConf. sparkConf.getAll.foreach { case (k, v) => if (k.toLowerCase.contains("password")) { logDebug(s"Applying Spark config to Hive Conf: $k=xxx") @@ -155,7 +173,7 @@ private[hive] class HiveClientImpl( } hiveConf.set(k, v) } - // Second, we set all entries in config to this hiveConf. + // 3: we set all entries in config to this hiveConf. extraConfig.foreach { case (k, v) => if (k.toLowerCase.contains("password")) { logDebug(s"Applying extra config to HiveConf: $k=xxx") @@ -360,8 +378,9 @@ private[hive] class HiveClientImpl( tableType = h.getTableType match { case HiveTableType.EXTERNAL_TABLE => CatalogTableType.EXTERNAL case HiveTableType.MANAGED_TABLE => CatalogTableType.MANAGED - case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIEW + case HiveTableType.INDEX_TABLE => + throw new AnalysisException("Hive index table is not supported.") }, schema = schema, partitionColumnNames = partCols.map(_.name), @@ -393,7 +412,10 @@ private[hive] class HiveClientImpl( serdeProperties = Option(h.getTTable.getSd.getSerdeInfo.getParameters) .map(_.asScala.toMap).orNull ), - properties = properties, + // For EXTERNAL_TABLE, the table properties has a particular field "EXTERNAL". This is added + // in the function toHiveTable. + properties = properties.filter(kv => kv._1 != "comment" && kv._1 != "EXTERNAL"), + comment = properties.get("comment"), viewOriginalText = Option(h.getViewOriginalText), viewText = Option(h.getViewExpandedText), unsupportedFeatures = unsupportedFeatures) @@ -742,7 +764,6 @@ private[hive] class HiveClientImpl( HiveTableType.EXTERNAL_TABLE case CatalogTableType.MANAGED => HiveTableType.MANAGED_TABLE - case CatalogTableType.INDEX => HiveTableType.INDEX_TABLE case CatalogTableType.VIEW => HiveTableType.VIRTUAL_VIEW }) // Note: In Hive the schema and partition columns must be disjoint sets @@ -817,6 +838,8 @@ private[hive] class HiveClientImpl( serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), compressed = apiPartition.getSd.isCompressed, serdeProperties = Option(apiPartition.getSd.getSerdeInfo.getParameters) - .map(_.asScala.toMap).orNull)) + .map(_.asScala.toMap).orNull), + parameters = + if (hp.getParameters() != null) hp.getParameters().asScala.toMap else Map.empty) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 9df4a26d55a2..13a8741cdc57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -251,6 +251,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { val table = hive.getTable(database, tableName) parts.foreach { s => val location = s.storage.locationUri.map(new Path(table.getPath, _)).orNull + val params = if (s.parameters.nonEmpty) s.parameters.asJava else null val spec = s.spec.asJava if (hive.getPartition(table, spec, false) != null && ignoreIfExists) { // Ignore this partition since it already exists and ignoreIfExists == true @@ -264,7 +265,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { table, spec, location, - null, // partParams + params, // partParams null, // inputFormat null, // outputFormat -1: JInteger, // numBuckets @@ -417,8 +418,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = { val addPartitionDesc = new AddPartitionDesc(db, table, ignoreIfExists) - parts.foreach { s => + parts.zipWithIndex.foreach { case (s, i) => addPartitionDesc.addPartition(s.spec.asJava, s.storage.locationUri.orNull) + if (s.parameters.nonEmpty) { + addPartitionDesc.getPartition(i).setPartParams(s.parameters.asJava) + } } hive.createPartitions(addPartitionDesc) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 15a5d79dcb08..a12b223b76ef 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.hive.MetastoreRelation * @param ignoreIfExists allow continue working if it's already exists, otherwise * raise exception */ -private[hive] case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, @@ -43,7 +42,7 @@ case class CreateHiveTableAsSelectCommand( private val tableIdentifier = tableDesc.identifier - override def children: Seq[LogicalPlan] = Seq(query) + override def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { lazy val metastoreRelation: MetastoreRelation = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index cc3e74b4e8cc..a716a3eab621 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -54,7 +54,7 @@ case class HiveTableScanExec( require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def producedAttributes: AttributeSet = outputSet ++ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3d58d490a51e..6a091aa6de1a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -147,8 +147,7 @@ case class InsertIntoHiveTable( val hadoopConf = sessionState.newHadoopConf() val tmpLocation = getExternalTmpPath(tableLocation, hadoopConf) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val isCompressed = - sessionState.conf.getConfString("hive.exec.compress.output", "false").toBoolean + val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean if (isCompressed) { // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", @@ -182,15 +181,13 @@ case class InsertIntoHiveTable( // Validate partition spec if there exist any dynamic partitions if (numDynamicPartitions > 0) { // Report error if dynamic partitioning is not enabled - if (!sessionState.conf.getConfString("hive.exec.dynamic.partition", "true").toBoolean) { + if (!hadoopConf.get("hive.exec.dynamic.partition", "true").toBoolean) { throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) } // Report error if dynamic partition strict mode is on but no static partition is found if (numStaticPartitions == 0 && - sessionState.conf.getConfString( - "hive.exec.dynamic.partition.mode", "strict").equalsIgnoreCase("strict")) - { + hadoopConf.get("hive.exec.dynamic.partition.mode", "strict").equalsIgnoreCase("strict")) { throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index dfb12512a40f..9747abbf15a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -51,7 +51,6 @@ import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfig * @param script the command that should be executed. * @param output the attributes that are produced by the script. */ -private[hive] case class ScriptTransformation( input: Seq[Expression], script: String, @@ -336,7 +335,6 @@ private class ScriptTransformationWriterThread( } } -private[hive] object HiveScriptIOSchema { def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = { HiveScriptIOSchema( @@ -355,7 +353,6 @@ object HiveScriptIOSchema { /** * The wrapper class of Hive input and output schema properties */ -private[hive] case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], outputRowFormat: Seq[(String, String)], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index c53675694f62..a5f800d4c568 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -150,7 +150,8 @@ private[hive] case class HiveGenericUDF( returnInspector // Make sure initialized. var i = 0 - while (i < children.length) { + val length = children.length + while (i < length) { val idx = i deferredObjects(i).asInstanceOf[DeferredObjectAdapter] .set(() => children(idx).eval(input)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index a2c8092e01bb..fc126b361697 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, Outp import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.{Row, SparkSession} @@ -47,7 +48,7 @@ import org.apache.spark.util.SerializableConfiguration * [[FileFormat]] for reading ORC files. If this is moved or renamed, please update * [[DataSource]]'s backwardCompatibilityMap. */ -private[sql] class OrcFileFormat +class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable { override def shortName(): String = "orc" @@ -150,12 +151,15 @@ private[sql] class OrcFileFormat new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength) } + val recordsIterator = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) + // Unwraps `OrcStruct`s to `UnsafeRow`s OrcRelation.unwrapOrcStructs( conf, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), - new RecordReaderIterator[OrcStruct](orcRecordReader)) + recordsIterator) } } } @@ -194,7 +198,8 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) row: InternalRow): Unit = { val fieldRefs = oi.getAllStructFieldRefs var i = 0 - while (i < fieldRefs.size) { + val size = fieldRefs.size + while (i < size) { oi.setStructFieldData( struct, @@ -358,7 +363,8 @@ private[orc] object OrcRelation extends HiveInspectors { iterator.map { value => val raw = deserializer.deserialize(value) var i = 0 - while (i < fieldRefs.length) { + val length = fieldRefs.length + while (i < length) { val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 91cf0dc960d5..c2a126d3bf9c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.hive.orc /** * Options for the ORC data source. */ -private[orc] class OrcOptions( - @transient private val parameters: Map[String, String]) +private[orc] class OrcOptions(@transient private val parameters: Map[String, String]) extends Serializable { import OrcOptions._ @@ -31,7 +30,14 @@ private[orc] class OrcOptions( * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ val compressionCodec: String = { - val codecName = parameters.getOrElse("compression", "snappy").toLowerCase + // `orc.compress` is a ORC configuration. So, here we respect this as an option but + // `compression` has higher precedence than `orc.compress`. It means if both are set, + // we will use `compression`. + val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(orcCompressionConf) + .getOrElse("snappy").toLowerCase if (!shortOrcCompressionCodecNames.contains(codecName)) { val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase) throw new IllegalArgumentException(s"Codec [$codecName] " + diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql index 3e2111d58a3c..ec881a216e0b 100644 --- a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql +++ b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql @@ -5,4 +5,4 @@ FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2 JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11) ORDER BY subq.key1, z.value -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = "2008-04-08")) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3 +SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = '2008-04-08')) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/case_with_key.sql b/sql/hive/src/test/resources/sqlgen/case_with_key.sql index dff65f10835f..e991ebafdc90 100644 --- a/sql/hive/src/test/resources/sqlgen/case_with_key.sql +++ b/sql/hive/src/test/resources/sqlgen/case_with_key.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN "foo" WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN "bar" END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 +SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN 'foo' WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN 'bar' END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql b/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql index af3e169b5431..492777e376ec 100644 --- a/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql +++ b/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0 -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar ELSE baz END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN "foo" WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN "bar" ELSE "baz" END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 +SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar ELSE baz END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN 'foo' WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN 'bar' ELSE 'baz' END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/inline_tables.sql b/sql/hive/src/test/resources/sqlgen/inline_tables.sql new file mode 100644 index 000000000000..18803a3ee59b --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/inline_tables.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (VALUES ('one', 1), ('two', 2), ('three', CAST(NULL AS INT)) AS gen_subquery_0(gen_attr_0, gen_attr_1)) AS data WHERE (`gen_attr_1` > 1)) AS data diff --git a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql index 6f5562a20ccc..11e45a48f1b8 100644 --- a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql +++ b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql @@ -3,4 +3,4 @@ SELECT c0, c1, c2 FROM parquet_t3 LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `c0`, `gen_attr_1` AS `c1`, `gen_attr_2` AS `c2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, "f1", "f2", "f3") gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt +SELECT `gen_attr_0` AS `c0`, `gen_attr_1` AS `c1`, `gen_attr_2` AS `c2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, 'f1', 'f2', 'f3') gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt diff --git a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql index 0d4f67f18426..d86b39df5744 100644 --- a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql +++ b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql @@ -3,4 +3,4 @@ SELECT a, b, c FROM parquet_t3 LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt AS a, b, c -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_2` AS `c` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, "f1", "f2", "f3") gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_2` AS `c` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, 'f1', 'f2', 'f3') gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt diff --git a/sql/hive/src/test/resources/sqlgen/not_like.sql b/sql/hive/src/test/resources/sqlgen/not_like.sql index da39a62225a5..22485045e212 100644 --- a/sql/hive/src/test/resources/sqlgen/not_like.sql +++ b/sql/hive/src/test/resources/sqlgen/not_like.sql @@ -1,4 +1,4 @@ -- This file is automatically generated by LogicalPlanToSQLSuite. SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%' -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 WHERE (NOT CAST((`gen_attr_0` + CAST(5 AS BIGINT)) AS STRING) LIKE "1%")) AS t0 +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 WHERE (NOT CAST((`gen_attr_0` + CAST(5 AS BIGINT)) AS STRING) LIKE '1%')) AS t0 diff --git a/sql/hive/src/test/resources/sqlgen/range.sql b/sql/hive/src/test/resources/sqlgen/range.sql new file mode 100644 index 000000000000..53c72ea71e6a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/range.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from range(100) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(0, 100, 1)) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/range_with_splits.sql b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql new file mode 100644 index 000000000000..83d637d54a30 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from range(1, 100, 20, 10) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(1, 100, 20, 10)) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql index d598e4c036a2..bd28d8dca94c 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql @@ -5,4 +5,4 @@ where exists (select a.key from src a where b.value = a.value and a.key = b.key and a.value > 'val_9') -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_9")) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql index a353c33af21a..d2965fc0b9b7 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql @@ -6,4 +6,4 @@ from (select * from src a where b.value = a.value and a.key = b.key and a.value > 'val_9')) a -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_9")) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS a) AS a +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS a) AS a diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql index f6873d24e16e..93ce902b7599 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql @@ -6,4 +6,4 @@ having exists (select a.key from src a where a.key = b.key and a.value > 'val_9') -------------------------------------------------------------------------------- -SELECT `gen_attr_1` AS `key`, `gen_attr_2` AS `count(1)` FROM (SELECT `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_0` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > "val_9")) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3)) AS b +SELECT `gen_attr_1` AS `key`, `gen_attr_2` AS `count(1)` FROM (SELECT `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_0` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > 'val_9')) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3)) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql index 8452ef946f61..411e073f0d28 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql @@ -7,4 +7,4 @@ from (select b.key, count(*) from src a where a.key = b.key and a.value > 'val_9')) a -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > "val_9")) AS gen_subquery_1 WHERE (`gen_attr_2` = `gen_attr_0`)) AS gen_subquery_3)) AS a) AS a +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > 'val_9')) AS gen_subquery_1 WHERE (`gen_attr_2` = `gen_attr_0`)) AS gen_subquery_3)) AS a) AS a diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql index 2ef38ce42944..b2ed0b0557af 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql @@ -6,4 +6,4 @@ having exists (select a.key from src a where a.value > 'val_9' and a.value = min(b.value)) -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_4`) AS `gen_attr_1`, min(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_5` AS `1` FROM (SELECT 1 AS `gen_attr_5` FROM (SELECT `gen_attr_6`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_2` > "val_9")) AS gen_subquery_2 WHERE (`gen_attr_2` = `gen_attr_3`)) AS gen_subquery_4)) AS gen_subquery_1) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_4`) AS `gen_attr_1`, min(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_5` AS `1` FROM (SELECT 1 AS `gen_attr_5` FROM (SELECT `gen_attr_6`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_2 WHERE (`gen_attr_2` = `gen_attr_3`)) AS gen_subquery_4)) AS gen_subquery_1) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql index bfa58211b12f..9894f5ab39c7 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql @@ -5,4 +5,4 @@ group by key having count(*) in (select count(*) from src s1 where s1.key = '90' group by s1.key) order by key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (`gen_attr_2` IN (SELECT `gen_attr_5` AS `_c0` FROM (SELECT `gen_attr_3` AS `gen_attr_5` FROM (SELECT count(1) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_7` FROM `default`.`src`) AS gen_subquery_3 WHERE (CAST(`gen_attr_6` AS DOUBLE) = CAST("90" AS DOUBLE)) GROUP BY `gen_attr_6`) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS src +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (`gen_attr_2` IN (SELECT `gen_attr_5` AS `_c0` FROM (SELECT `gen_attr_3` AS `gen_attr_5` FROM (SELECT count(1) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_7` FROM `default`.`src`) AS gen_subquery_3 WHERE (CAST(`gen_attr_6` AS DOUBLE) = CAST('90' AS DOUBLE)) GROUP BY `gen_attr_6`) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS src diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql index f7503bce068f..c3a122aa889b 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -7,4 +7,4 @@ having b.key in (select a.key where a.value > 'val_9' and a.value = min(b.value)) order by b.key -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > "val_9")) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql index 54a38ec0edb4..eed20a5d311f 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql @@ -5,4 +5,4 @@ where not exists (select a.key from src a where b.value = a.value and a.key = b.key and a.value > 'val_2') -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_2")) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3))) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_2')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql index c05bb5d991b4..7040e106e7ba 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql @@ -5,4 +5,4 @@ where not exists (select a.key from src a where b.value = a.value and a.value > 'val_2') -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT `gen_attr_4`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > "val_2")) AS gen_subquery_1 WHERE (`gen_attr_1` = `gen_attr_2`)) AS gen_subquery_3))) AS b +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT `gen_attr_4`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_2')) AS gen_subquery_1 WHERE (`gen_attr_1` = `gen_attr_2`)) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql index d6047c52f20f..3c0e90ed4222 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql @@ -6,4 +6,4 @@ having not exists (select a.key from src a where b.value = a.value and a.key = b.key and a.value > 'val_12') -------------------------------------------------------------------------------- -SELECT `gen_attr_3` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_3`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_3`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > "val_12")) AS gen_subquery_1 WHERE ((`gen_attr_0` = `gen_attr_1`) AND (`gen_attr_2` = `gen_attr_3`))) AS gen_subquery_3))) AS b +SELECT `gen_attr_3` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_3`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_3`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > 'val_12')) AS gen_subquery_1 WHERE ((`gen_attr_0` = `gen_attr_1`) AND (`gen_attr_2` = `gen_attr_3`))) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql index 8b5402d8aa77..0c16f9e58b9b 100644 --- a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql @@ -6,4 +6,4 @@ having not exists (select distinct a.key from src a where b.value = a.value and a.value > 'val_12') -------------------------------------------------------------------------------- -SELECT `gen_attr_2` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_2`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_2`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT DISTINCT `gen_attr_4`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > "val_12")) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3))) AS b +SELECT `gen_attr_2` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_2`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_2`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT DISTINCT `gen_attr_4`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > 'val_12')) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index fef726c5d801..43a218b4d14b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -24,14 +24,22 @@ import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFr class ExpressionSQLBuilderSuite extends SQLBuilderTest { test("literal") { - checkSQL(Literal("foo"), "\"foo\"") - checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal("foo"), "'foo'") + checkSQL(Literal("\"foo\""), "'\"foo\"'") + checkSQL(Literal("'foo'"), "'\\'foo\\''") checkSQL(Literal(1: Byte), "1Y") checkSQL(Literal(2: Short), "2S") checkSQL(Literal(4: Int), "4") checkSQL(Literal(8: Long), "8L") checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(Float.PositiveInfinity), "CAST('Infinity' AS FLOAT)") + checkSQL(Literal(Float.NegativeInfinity), "CAST('-Infinity' AS FLOAT)") + checkSQL(Literal(Float.NaN), "CAST('NaN' AS FLOAT)") checkSQL(Literal(2.5D), "2.5D") + checkSQL(Literal(Double.PositiveInfinity), "CAST('Infinity' AS DOUBLE)") + checkSQL(Literal(Double.NegativeInfinity), "CAST('-Infinity' AS DOUBLE)") + checkSQL(Literal(Double.NaN), "CAST('NaN' AS DOUBLE)") + checkSQL(Literal(BigDecimal("10.0000000").underlying), "10.0000000BD") checkSQL( Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')") // TODO tests for decimals @@ -75,8 +83,8 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL('a.int / 'b.int, "(`a` / `b`)") checkSQL('a.int % 'b.int, "(`a` % `b`)") - checkSQL(-'a.int, "(-`a`)") - checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + checkSQL(-'a.int, "(- `a`)") + checkSQL(-('a.int + 'b.int), "(- (`a` + `b`))") } test("window specification") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index d8ab864ca6fc..ef2f756a4bde 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -23,7 +23,10 @@ import java.nio.file.{Files, NoSuchFileException, Paths} import scala.util.control.NonFatal import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -41,15 +44,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { import testImplicits._ // Used for generating new query answer files by saving - private val regenerateGoldenFiles: Boolean = - Option(System.getenv("SPARK_GENERATE_GOLDEN_FILES")) == Some("1") + private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" private val goldenSQLPath = "src/test/resources/sqlgen/" protected override def beforeAll(): Unit = { super.beforeAll() - sql("DROP TABLE IF EXISTS parquet_t0") - sql("DROP TABLE IF EXISTS parquet_t1") - sql("DROP TABLE IF EXISTS parquet_t2") + (0 to 3).foreach { i => + sql(s"DROP TABLE IF EXISTS parquet_t$i") + } sql("DROP TABLE IF EXISTS t0") spark.range(10).write.saveAsTable("parquet_t0") @@ -85,10 +87,9 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { override protected def afterAll(): Unit = { try { - sql("DROP TABLE IF EXISTS parquet_t0") - sql("DROP TABLE IF EXISTS parquet_t1") - sql("DROP TABLE IF EXISTS parquet_t2") - sql("DROP TABLE IF EXISTS parquet_t3") + (0 to 3).foreach { i => + sql(s"DROP TABLE IF EXISTS parquet_t$i") + } sql("DROP TABLE IF EXISTS t0") } finally { super.afterAll() @@ -181,7 +182,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("Test should fail if the SQL query cannot be regenerated") { - spark.range(10).createOrReplaceTempView("not_sql_gen_supported_table_so_far") + case class Unsupported() extends LeafNode with MultiInstanceRelation { + override def newInstance(): Unsupported = copy() + override def output: Seq[Attribute] = Nil + } + Unsupported().createOrReplaceTempView("not_sql_gen_supported_table_so_far") sql("select * from not_sql_gen_supported_table_so_far") val m3 = intercept[org.scalatest.exceptions.TestFailedException] { checkSQL("select * from not_sql_gen_supported_table_so_far", "in") @@ -197,6 +202,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } } + test("range") { + checkSQL("select * from range(100)", "range") + checkSQL("select * from range(1, 100, 20, 10)", "range_with_splits") + } + test("in") { checkSQL("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)", "in") } @@ -1103,4 +1113,12 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSQL("select * from orc_t", "select_orc_table") } } + + test("inline tables") { + checkSQL( + """ + |select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1 + """.stripMargin, + "inline_tables") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 867aadb5f556..54009d4b4130 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -520,8 +521,13 @@ class HiveDDLCommandSuite extends PlanTest { } } - test("MSCK repair table (not supported)") { - assertUnsupported("MSCK REPAIR TABLE tab1") + test("MSCK REPAIR table") { + val sql = "MSCK REPAIR TABLE tab1" + val parsed = parser.parsePlan(sql) + val expected = AlterTableRecoverPartitionsCommand( + TableIdentifier("tab1", None), + "MSCK REPAIR TABLE") + comparePlans(parsed, expected) } test("create table like") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 9bca720a9473..0c7a0446ffb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.sql.hive @@ -153,12 +171,14 @@ class HiveSparkSubmitSuite case x => throw new Exception(s"Unsupported Scala Version: $x") } val testJar = s"sql/hive/src/test/resources/regression-test-SPARK-8489/test-$version.jar" + val testJarPath = sys.props.get("spark.project.home").map( + _ + '/' + testJar).getOrElse(testJar) val args = Seq( "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", "--driver-java-options", "-Dderby.system.durability=test", "--class", "Main", - testJar) + testJarPath) runSparkSubmit(args) } @@ -253,6 +273,47 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-16901: set javax.jdo.option.ConnectionURL") { + // In this test, we set javax.jdo.option.ConnectionURL and set metastore version to + // 0.13. This test will make sure that javax.jdo.option.ConnectionURL will not be + // overridden by hive's default settings when we create a HiveConf object inside + // HiveClientImpl. Please see SPARK-16901 for more details. + + val metastoreLocation = Utils.createTempDir() + metastoreLocation.delete() + val metastoreURL = + s"jdbc:derby:memory:;databaseName=${metastoreLocation.getAbsolutePath};create=true" + val hiveSiteXmlContent = + s""" + | + | + | javax.jdo.option.ConnectionURL + | $metastoreURL + | + | + """.stripMargin + + // Write a hive-site.xml containing a setting of hive.metastore.warehouse.dir. + val hiveSiteDir = Utils.createTempDir() + val file = new File(hiveSiteDir.getCanonicalPath, "hive-site.xml") + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(hiveSiteXmlContent) + bw.close() + + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SetMetastoreURLTest.getClass.getName.stripSuffix("$"), + "--name", "SetMetastoreURLTest", + "--master", "local[1]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.sql.test.expectedMetastoreURL=$metastoreURL", + "--conf", s"spark.driver.extraClassPath=${hiveSiteDir.getCanonicalPath}", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -313,6 +374,45 @@ class HiveSparkSubmitSuite } } +object SetMetastoreURLTest extends Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkConf = new SparkConf(loadDefaults = true) + val builder = SparkSession.builder() + .config(sparkConf) + .config("spark.ui.enabled", "false") + .config("spark.sql.hive.metastore.version", "0.13.1") + // The issue described in SPARK-16901 only appear when + // spark.sql.hive.metastore.jars is not set to builtin. + .config("spark.sql.hive.metastore.jars", "maven") + .enableHiveSupport() + + val spark = builder.getOrCreate() + val expectedMetastoreURL = + spark.conf.get("spark.sql.test.expectedMetastoreURL") + logInfo(s"spark.sql.test.expectedMetastoreURL is $expectedMetastoreURL") + + if (expectedMetastoreURL == null) { + throw new Exception( + s"spark.sql.test.expectedMetastoreURL should be set.") + } + + // HiveSharedState is used when Hive support is enabled. + val actualMetastoreURL = + spark.sharedState.asInstanceOf[HiveSharedState] + .metadataHive + .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") + logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") + + if (actualMetastoreURL != expectedMetastoreURL) { + throw new Exception( + s"Expected value of javax.jdo.option.ConnectionURL is $expectedMetastoreURL. But, " + + s"the actual value is $actualMetastoreURL") + } + } +} + object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala similarity index 52% rename from core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index f6e46ae9a481..667a7ddd8bb6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -15,10 +15,22 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.sql.hive -import org.apache.spark.SparkException +import org.apache.hadoop.hive.conf.HiveConf.ConfVars -private[spark] -case class BlockFetchException(messages: String, throwable: Throwable) - extends SparkException(messages, throwable) +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.QueryTest + +class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + test("newTemporaryConfiguration overwrites listener configurations") { + Seq(true, false).foreach { useInMemoryDerby => + val conf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby) + assert(conf(ConfVars.METASTORE_PRE_EVENT_LISTENERS.varname) === "") + assert(conf(ConfVars.METASTORE_EVENT_LISTENERS.varname) === "") + assert(conf(ConfVars.METASTORE_END_FUNCTION_LISTENERS.varname) === "") + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index af071f95e69f..28c8139072ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -333,7 +333,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv }.getMessage assert( - message.contains("Table ctasJsonTable already exists."), + message.contains("Table default.ctasJsonTable already exists."), "We should complain that ctasJsonTable already exists") // The following statement should be fine if it has IF NOT EXISTS. @@ -748,7 +748,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(schema === actualSchema) // Checks the DESCRIBE output. - checkAnswer(sql("DESCRIBE spark6655"), Row("int", "int", "") :: Nil) + checkAnswer(sql("DESCRIBE spark6655"), Row("int", "int", null) :: Nil) } } @@ -901,7 +901,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val e = intercept[AnalysisException] { createDF(10, 19).write.mode(SaveMode.Append).format("orc").saveAsTable("appendOrcToParquet") } - assert(e.getMessage.contains("The file format of the existing table `appendOrcToParquet` " + + assert(e.getMessage.contains( + "The file format of the existing table default.appendOrcToParquet " + "is `org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat`. " + "It doesn't match the specified format `orc`")) } @@ -912,7 +913,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv createDF(10, 19).write.mode(SaveMode.Append).format("parquet") .saveAsTable("appendParquetToJson") } - assert(e.getMessage.contains("The file format of the existing table `appendParquetToJson` " + + assert(e.getMessage.contains( + "The file format of the existing table default.appendParquetToJson " + "is `org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. " + "It doesn't match the specified format `parquet`")) } @@ -923,7 +925,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv createDF(10, 19).write.mode(SaveMode.Append).format("text") .saveAsTable("appendTextToJson") } - assert(e.getMessage.contains("The file format of the existing table `appendTextToJson` is " + + assert(e.getMessage.contains( + "The file format of the existing table default.appendTextToJson is " + "`org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. " + "It doesn't match the specified format `text`")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala new file mode 100644 index 000000000000..eec60b440720 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala @@ -0,0 +1,38 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} + +class MetastoreRelationSuite extends SparkFunSuite { + test("makeCopy and toJSON should work") { + val table = CatalogTable( + identifier = TableIdentifier("test", Some("db")), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = Seq.empty[CatalogColumn]) + val relation = MetastoreRelation("db", "test", None)(table, null, null) + + // No exception should be thrown + relation.makeCopy(Array("db", "test", None)) + // No exception should be thrown + relation.toJSON + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index 5d510197c4d9..446029fdc6e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -404,25 +404,24 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto |USING org.apache.spark.sql.parquet.DefaultSource """.stripMargin) // An empty sequence of row is returned for session temporary table. - val message1 = intercept[AnalysisException] { + intercept[NoSuchTableException] { sql("SHOW PARTITIONS parquet_temp") - }.getMessage - assert(message1.contains("is not allowed on a temporary table")) + } - val message2 = intercept[AnalysisException] { + val message1 = intercept[AnalysisException] { sql("SHOW PARTITIONS parquet_tab3") }.getMessage - assert(message2.contains("not allowed on a table that is not partitioned")) + assert(message1.contains("not allowed on a table that is not partitioned")) - val message3 = intercept[AnalysisException] { + val message2 = intercept[AnalysisException] { sql("SHOW PARTITIONS parquet_tab4 PARTITION(abcd=2015, xyz=1)") }.getMessage - assert(message3.contains("Non-partitioning column(s) [abcd, xyz] are specified")) + assert(message2.contains("Non-partitioning column(s) [abcd, xyz] are specified")) - val message4 = intercept[AnalysisException] { + val message3 = intercept[AnalysisException] { sql("SHOW PARTITIONS parquet_view1") }.getMessage - assert(message4.contains("is not allowed on a view or index table")) + assert(message3.contains("is not allowed on a view")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 92282420214d..0a6ccbed8493 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -24,8 +24,11 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.command.{CreateDataSourceTableUtils, DDLUtils} +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ +import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -135,8 +138,11 @@ class HiveDDLSuite sql(s"CREATE VIEW $viewName COMMENT 'no comment' AS SELECT * FROM $tabName") val tableMetadata = catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) val viewMetadata = catalog.getTableMetadata(TableIdentifier(viewName, Some("default"))) - assert(tableMetadata.properties.get("comment") == Option("BLABLA")) - assert(viewMetadata.properties.get("comment") == Option("no comment")) + assert(tableMetadata.comment == Option("BLABLA")) + assert(viewMetadata.comment == Option("no comment")) + // Ensure that `comment` is removed from the table property + assert(tableMetadata.properties.get("comment").isEmpty) + assert(viewMetadata.properties.get("comment").isEmpty) } } } @@ -286,7 +292,7 @@ class HiveDDLSuite sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") }.getMessage assert(message.contains( - "Attempted to unset non-existent property 'p' in table '`view1`'")) + "Attempted to unset non-existent property 'p' in table '`default`.`view1`'")) } } } @@ -368,6 +374,44 @@ class HiveDDLSuite expectedSerdeProps) } + test("MSCK REPAIR RABLE") { + val catalog = spark.sessionState.catalog + val tableIdent = TableIdentifier("tab1") + sql("CREATE TABLE tab1 (height INT, length INT) PARTITIONED BY (a INT, b INT)") + val part1 = Map("a" -> "1", "b" -> "5") + val part2 = Map("a" -> "2", "b" -> "6") + val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + // valid + fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) + fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file + fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "_SUCCESS")) // file + fs.mkdirs(new Path(new Path(root, "A=2"), "B=6")) + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "b.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "c.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), ".hiddenFile")) // file + fs.mkdirs(new Path(new Path(root, "A=2/B=6"), "_temporary")) + + // invalid + fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name + fs.mkdirs(new Path(new Path(root, "b=1"), "a=1")) // wrong order + fs.mkdirs(new Path(root, "a=4")) // not enough columns + fs.createNewFile(new Path(new Path(root, "a=1"), "b=4")) // file + fs.createNewFile(new Path(new Path(root, "a=1"), "_SUCCESS")) // _SUCCESS + fs.mkdirs(new Path(new Path(root, "a=1"), "_temporary")) // _temporary + fs.mkdirs(new Path(new Path(root, "a=1"), ".b=4")) // start with . + + try { + sql("MSCK REPAIR TABLE tab1") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2)) + assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } finally { + fs.delete(root, true) + } + } + test("drop table using drop view") { withTable("tab1") { sql("CREATE TABLE tab1(c1 int)") @@ -431,6 +475,22 @@ class HiveDDLSuite } } + test("desc table for Hive table - partitioned table") { + withTable("tbl") { + sql("CREATE TABLE tbl(a int) PARTITIONED BY (b int)") + + assert(sql("DESC tbl").collect().containsSlice( + Seq( + Row("a", "int", null), + Row("b", "int", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("b", "int", null) + ) + )) + } + } + test("desc table for data source table using Hive Metastore") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") val tabName = "tab1" @@ -594,6 +654,248 @@ class HiveDDLSuite } } + + test("CREATE TABLE LIKE a temporary view") { + val sourceViewName = "tab1" + val targetTabName = "tab2" + withTempView(sourceViewName) { + withTable(targetTabName) { + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .createTempView(sourceViewName) + sql(s"CREATE TABLE $targetTabName LIKE $sourceViewName") + + val sourceTable = spark.sessionState.catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier(sourceViewName)) + val targetTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceTable, targetTable) + } + } + } + + test("CREATE TABLE LIKE a data source table") { + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .write.format("json").saveAsTable(sourceTabName) + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + val sourceTable = + spark.sessionState.catalog.getTableMetadata(TableIdentifier(sourceTabName, Some("default"))) + val targetTable = + spark.sessionState.catalog.getTableMetadata(TableIdentifier(targetTabName, Some("default"))) + // The table type of the source table should be a Hive-managed data source table + assert(DDLUtils.isDatasourceTable(sourceTable)) + assert(sourceTable.tableType == CatalogTableType.MANAGED) + + checkCreateTableLike(sourceTable, targetTable) + } + } + + test("CREATE TABLE LIKE an external data source table") { + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .write.format("parquet").save(path) + sql(s"CREATE TABLE $sourceTabName USING parquet OPTIONS (PATH '$path')") + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + // The source table should be an external data source table + val sourceTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(sourceTabName, Some("default"))) + val targetTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(targetTabName, Some("default"))) + // The table type of the source table should be an external data source table + assert(DDLUtils.isDatasourceTable(sourceTable)) + assert(sourceTable.tableType == CatalogTableType.EXTERNAL) + + checkCreateTableLike(sourceTable, targetTable) + } + } + } + + test("CREATE TABLE LIKE a managed Hive serde table") { + val catalog = spark.sessionState.catalog + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + sql(s"CREATE TABLE $sourceTabName TBLPROPERTIES('prop1'='value1') AS SELECT 1 key, 'a'") + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + val sourceTable = catalog.getTableMetadata(TableIdentifier(sourceTabName, Some("default"))) + assert(sourceTable.tableType == CatalogTableType.MANAGED) + assert(sourceTable.properties.get("prop1").nonEmpty) + val targetTable = catalog.getTableMetadata(TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceTable, targetTable) + } + } + + test("CREATE TABLE LIKE an external Hive serde table") { + val catalog = spark.sessionState.catalog + withTempDir { tmpDir => + val basePath = tmpDir.getCanonicalPath + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + assert(tmpDir.listFiles.isEmpty) + sql( + s""" + |CREATE EXTERNAL TABLE $sourceTabName (key INT comment 'test', value STRING) + |COMMENT 'Apache Spark' + |PARTITIONED BY (ds STRING, hr STRING) + |LOCATION '$basePath' + """.stripMargin) + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + sql( + s""" + |INSERT OVERWRITE TABLE $sourceTabName + |partition (ds='$ds',hr='$hr') + |SELECT 1, 'a' + """.stripMargin) + } + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + val sourceTable = catalog.getTableMetadata(TableIdentifier(sourceTabName, Some("default"))) + assert(sourceTable.tableType == CatalogTableType.EXTERNAL) + assert(sourceTable.comment == Option("Apache Spark")) + val targetTable = catalog.getTableMetadata(TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceTable, targetTable) + } + } + } + + test("CREATE TABLE LIKE a view") { + val sourceTabName = "tab1" + val sourceViewName = "view" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + withView(sourceViewName) { + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .write.format("json").saveAsTable(sourceTabName) + sql(s"CREATE VIEW $sourceViewName AS SELECT * FROM $sourceTabName") + sql(s"CREATE TABLE $targetTabName LIKE $sourceViewName") + + val sourceView = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(sourceViewName, Some("default"))) + // The original source should be a VIEW with an empty path + assert(sourceView.tableType == CatalogTableType.VIEW) + assert(sourceView.viewText.nonEmpty && sourceView.viewOriginalText.nonEmpty) + val targetTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceView, targetTable) + } + } + } + + private def getTablePath(table: CatalogTable): Option[String] = { + if (DDLUtils.isDatasourceTable(table)) { + new CaseInsensitiveMap(table.storage.serdeProperties).get("path") + } else { + table.storage.locationUri + } + } + + private def checkCreateTableLike(sourceTable: CatalogTable, targetTable: CatalogTable): Unit = { + // The created table should be a MANAGED table with empty view text and original text. + assert(targetTable.tableType == CatalogTableType.MANAGED, + "the created table must be a Hive managed table") + assert(targetTable.viewText.isEmpty && targetTable.viewOriginalText.isEmpty, + "the view text and original text in the created table must be empty") + assert(targetTable.comment.isEmpty, + "the comment in the created table must be empty") + assert(targetTable.unsupportedFeatures.isEmpty, + "the unsupportedFeatures in the create table must be empty") + + val metastoreGeneratedProperties = Seq( + "CreateTime", + "transient_lastDdlTime", + "grantTime", + "lastUpdateTime", + "last_modified_by", + "last_modified_time", + "Owner:", + "COLUMN_STATS_ACCURATE", + "numFiles", + "numRows", + "rawDataSize", + "totalSize", + "totalNumberFiles", + "maxFileSize", + "minFileSize" + ) + assert(targetTable.properties.filterKeys { key => + !metastoreGeneratedProperties.contains(key) && !key.startsWith(DATASOURCE_PREFIX) + }.isEmpty, + "the table properties of source tables should not be copied in the created table") + + if (DDLUtils.isDatasourceTable(sourceTable) || + sourceTable.tableType == CatalogTableType.VIEW) { + assert(DDLUtils.isDatasourceTable(targetTable), + "the target table should be a data source table") + } else { + assert(!DDLUtils.isDatasourceTable(targetTable), + "the target table should be a Hive serde table") + } + + if (sourceTable.tableType == CatalogTableType.VIEW) { + // Source table is a temporary/permanent view, which does not have a provider. The created + // target table uses the default data source format + assert(targetTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER) == + spark.sessionState.conf.defaultDataSourceName) + } else if (DDLUtils.isDatasourceTable(sourceTable)) { + assert(targetTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER) == + sourceTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER)) + } + + val sourceTablePath = getTablePath(sourceTable) + val targetTablePath = getTablePath(targetTable) + assert(targetTablePath.nonEmpty, "target table path should not be empty") + assert(sourceTablePath != targetTablePath, + "source table/view path should be different from target table path") + + // The source table contents should not been seen in the target table. + assert(spark.table(sourceTable.identifier).count() != 0, "the source table should be nonempty") + assert(spark.table(targetTable.identifier).count() == 0, "the target table should be empty") + + // Their schema should be identical + checkAnswer( + sql(s"DESC ${sourceTable.identifier}").select("col_name", "data_type"), + sql(s"DESC ${targetTable.identifier}").select("col_name", "data_type")) + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + // Check whether the new table can be inserted using the data from the original table + sql(s"INSERT INTO TABLE ${targetTable.identifier} SELECT * FROM ${sourceTable.identifier}") + } + + // After insertion, the data should be identical + checkAnswer( + sql(s"SELECT * FROM ${sourceTable.identifier}"), + sql(s"SELECT * FROM ${targetTable.identifier}")) + } + + test("Analyze data source tables(LogicalRelation)") { + withTable("t1") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(1).write.format("parquet").save(path) + sql(s"CREATE TABLE t1 USING parquet OPTIONS (PATH '$path')") + val e = intercept[AnalysisException] { + sql("ANALYZE TABLE t1 COMPUTE STATISTICS") + }.getMessage + assert(e.contains("ANALYZE TABLE is only supported for Hive tables, " + + "but 't1' is a LogicalRelation")) + } + } + } + test("desc table for data source table") { withTable("tab1") { val tabName = "tab1" @@ -621,7 +923,7 @@ class HiveDDLSuite val desc = sql("DESC FORMATTED t1").collect().toSeq - assert(desc.contains(Row("id", "bigint", ""))) + assert(desc.contains(Row("id", "bigint", null))) } } } @@ -638,13 +940,13 @@ class HiveDDLSuite assert(formattedDesc.containsSlice( Seq( - Row("a", "bigint", ""), - Row("b", "bigint", ""), - Row("c", "bigint", ""), - Row("d", "bigint", ""), + Row("a", "bigint", null), + Row("b", "bigint", null), + Row("c", "bigint", null), + Row("d", "bigint", null), Row("# Partition Information", "", ""), - Row("# col_name", "", ""), - Row("d", "", ""), + Row("# col_name", "data_type", "comment"), + Row("d", "bigint", null), Row("", "", ""), Row("# Detailed Table Information", "", ""), Row("Database:", "default", "") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 98afd99a203a..ec3328c0ae2d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -77,7 +77,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + test("SPARK-17230: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempView("jt") { val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) spark.read.json(rdd).createOrReplaceTempView("jt") @@ -98,8 +98,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(!outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should not contain Subquery since it's eliminated by optimizer") + assert(outputs.substring(physicalIndex).contains("SubqueryAlias"), + "Physical Plan should contain SubqueryAlias since the query should not be optimized") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f8c55ec45650..183a57f22cb7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -26,16 +26,17 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkException, SparkFiles} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.SparkFiles +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils case class TestData(a: Int, b: String) @@ -43,7 +44,7 @@ case class TestData(a: Int, b: String) * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. */ -class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { +class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault @@ -51,6 +52,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + def spark: SparkSession = sparkSession + override def beforeAll() { super.beforeAll() TestHive.setCacheTables(true) @@ -216,15 +219,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(new Timestamp(1000) == r1.getTimestamp(0)) } - createQueryTest("constant array", - """ - |SELECT sort_array( - | sort_array( - | array("hadoop distributed file system", - | "enterprise databases", "hadoop map-reduce"))) - |FROM src LIMIT 1; - """.stripMargin) - createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -834,8 +828,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertResult( Array( - Row("a", "int", ""), - Row("b", "string", "")) + Row("a", "int", null), + Row("b", "string", null)) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) @@ -1212,6 +1206,27 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } } + + test("dynamic partitioning is allowed when hive.exec.dynamic.partition.mode is nonstrict") { + val modeConfKey = "hive.exec.dynamic.partition.mode" + withTable("with_parts") { + sql("CREATE TABLE with_parts(key INT) PARTITIONED BY (p INT)") + + withSQLConf(modeConfKey -> "nonstrict") { + sql("INSERT OVERWRITE TABLE with_parts partition(p) select 1, 2") + assert(spark.table("with_parts").filter($"p" === 2).collect().head == Row(1, 2)) + } + + val originalValue = spark.sparkContext.hadoopConfiguration.get(modeConfKey, "nonstrict") + try { + spark.sparkContext.hadoopConfiguration.set(modeConfKey, "nonstrict") + sql("INSERT OVERWRITE TABLE with_parts partition(p) select 3, 4") + assert(spark.table("with_parts").filter($"p" === 4).collect().head == Row(3, 4)) + } finally { + spark.sparkContext.hadoopConfiguration.set(modeConfKey, originalValue) + } + } + } } // for SPARK-2180 test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c4d9e0aee911..51d49469b3f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -642,19 +642,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("specifying the column list for CTAS") { - Seq((1, "111111"), (2, "222222")).toDF("key", "value").createOrReplaceTempView("mytable1") + withTempView("mytable1") { + Seq((1, "111111"), (2, "222222")).toDF("key", "value").createOrReplaceTempView("mytable1") + withTable("gen__tmp") { + sql("create table gen__tmp as select key as a, value as b from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select key, value from mytable1").collect()) + } - sql("create table gen__tmp as select key as a, value as b from mytable1") - checkAnswer( - sql("SELECT a, b from gen__tmp"), - sql("select key, value from mytable1").collect()) - sql("DROP TABLE gen__tmp") + withTable("gen__tmp") { + val e = intercept[AnalysisException] { + sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + }.getMessage + assert(e.contains("Schema may not be specified in a Create Table As Select (CTAS)")) + } - intercept[AnalysisException] { - sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + withTable("gen__tmp") { + val e = intercept[AnalysisException] { + sql( + """ + |CREATE TABLE gen__tmp + |PARTITIONED BY (key string) + |AS SELECT key, value FROM mytable1 + """.stripMargin) + }.getMessage + assert(e.contains("A Create Table As Select (CTAS) statement is not allowed to " + + "create a partitioned table using Hive's file formats")) + } } - - sql("drop table mytable1") } test("command substitution") { @@ -1684,6 +1700,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-17354: Partitioning by dates/timestamps works with Parquet vectorized reader") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + sql( + """CREATE TABLE order(id INT) + |PARTITIONED BY (pd DATE, pt TIMESTAMP) + |STORED AS PARQUET + """.stripMargin) + + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql( + """INSERT INTO TABLE order PARTITION(pd, pt) + |SELECT 1 AS id, CAST('1990-02-24' AS DATE) AS pd, CAST('1990-02-24' AS TIMESTAMP) AS pt + """.stripMargin) + val actual = sql("SELECT * FROM order") + val expected = sql( + "SELECT 1 AS id, CAST('1990-02-24' AS DATE) AS pd, CAST('1990-02-24' AS TIMESTAMP) AS pt") + checkAnswer(actual, expected) + sql("DROP TABLE order") + } + } + def testCommandAvailable(command: String): Boolean = { Try(Process(command) !!).isSuccess } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala index 39846f145c42..490fea27de2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -55,6 +57,56 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("Issue exceptions for ALTER VIEW on the temporary view") { + val viewName = "testView" + withTempView(viewName) { + spark.range(10).createTempView(viewName) + assertNoSuchTable(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')") + assertNoSuchTable(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") + } + } + + test("Issue exceptions for ALTER TABLE on the temporary view") { + val viewName = "testView" + withTempView(viewName) { + spark.range(10).createTempView(viewName) + assertNoSuchTable(s"ALTER TABLE $viewName SET SERDE 'whatever'") + assertNoSuchTable(s"ALTER TABLE $viewName PARTITION (a=1, b=2) SET SERDE 'whatever'") + assertNoSuchTable(s"ALTER TABLE $viewName SET SERDEPROPERTIES ('p' = 'an')") + assertNoSuchTable(s"ALTER TABLE $viewName SET LOCATION '/path/to/your/lovely/heart'") + assertNoSuchTable(s"ALTER TABLE $viewName PARTITION (a='4') SET LOCATION '/path/to/home'") + assertNoSuchTable(s"ALTER TABLE $viewName ADD IF NOT EXISTS PARTITION (a='4', b='8')") + assertNoSuchTable(s"ALTER TABLE $viewName DROP PARTITION (a='4', b='8')") + assertNoSuchTable(s"ALTER TABLE $viewName PARTITION (a='4') RENAME TO PARTITION (a='5')") + assertNoSuchTable(s"ALTER TABLE $viewName RECOVER PARTITIONS") + } + } + + test("Issue exceptions for other table DDL on the temporary view") { + val viewName = "testView" + withTempView(viewName) { + spark.range(10).createTempView(viewName) + + val e = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $viewName SELECT 1") + }.getMessage + assert(e.contains("Inserting into an RDD-based table is not allowed")) + + val testData = hiveContext.getHiveFile("data/files/employee.dat").getCanonicalPath + assertNoSuchTable(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE $viewName""") + assertNoSuchTable(s"TRUNCATE TABLE $viewName") + assertNoSuchTable(s"SHOW CREATE TABLE $viewName") + assertNoSuchTable(s"SHOW PARTITIONS $viewName") + assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + } + } + + private def assertNoSuchTable(query: String): Unit = { + intercept[NoSuchTableException] { + sql(query) + } + } + test("error handling: fail if the view sql itself is invalid") { // A table that does not exist intercept[AnalysisException] { @@ -174,6 +226,70 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("should not allow ALTER VIEW AS when the view does not exist") { + assertNoSuchTable("ALTER VIEW testView AS SELECT 1, 2") + assertNoSuchTable("ALTER VIEW default.testView AS SELECT 1, 2") + } + + test("ALTER VIEW AS should try to alter temp view first if view name has no database part") { + withView("test_view") { + withTempView("test_view") { + sql("CREATE VIEW test_view AS SELECT 1 AS a, 2 AS b") + sql("CREATE TEMP VIEW test_view AS SELECT 1 AS a, 2 AS b") + + sql("ALTER VIEW test_view AS SELECT 3 AS i, 4 AS j") + + // The temporary view should be updated. + checkAnswer(spark.table("test_view"), Row(3, 4)) + + // The permanent view should stay same. + checkAnswer(spark.table("default.test_view"), Row(1, 2)) + } + } + } + + test("ALTER VIEW AS should alter permanent view if view name has database part") { + withView("test_view") { + withTempView("test_view") { + sql("CREATE VIEW test_view AS SELECT 1 AS a, 2 AS b") + sql("CREATE TEMP VIEW test_view AS SELECT 1 AS a, 2 AS b") + + sql("ALTER VIEW default.test_view AS SELECT 3 AS i, 4 AS j") + + // The temporary view should stay same. + checkAnswer(spark.table("test_view"), Row(1, 2)) + + // The permanent view should be updated. + checkAnswer(spark.table("default.test_view"), Row(3, 4)) + } + } + } + + test("ALTER VIEW AS should keep the previous table properties, comment, create_time, etc.") { + withView("test_view") { + sql( + """ + |CREATE VIEW test_view + |COMMENT 'test' + |TBLPROPERTIES ('key' = 'a') + |AS SELECT 1 AS a, 2 AS b + """.stripMargin) + + val catalog = spark.sessionState.catalog + val viewMeta = catalog.getTableMetadata(TableIdentifier("test_view")) + assert(viewMeta.comment == Some("test")) + assert(viewMeta.properties("key") == "a") + + sql("ALTER VIEW test_view AS SELECT 3 AS i, 4 AS j") + val updatedViewMeta = catalog.getTableMetadata(TableIdentifier("test_view")) + assert(updatedViewMeta.comment == Some("test")) + assert(updatedViewMeta.properties("key") == "a") + assert(updatedViewMeta.createTime == viewMeta.createTime) + // The view should be updated. + checkAnswer(spark.table("test_view"), Row(3, 4)) + } + } + Seq(true, false).foreach { enabled => val prefix = (if (enabled) "With" else "Without") + " canonical native view: " test(s"$prefix correctly handle CREATE OR REPLACE VIEW") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 46595ee8186a..8f8768217788 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -160,6 +160,29 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } + test("SPARK-16610: Respect orc.compress option when compression is unset") { + // Respect `orc.compress`. + withTempPath { file => + spark.range(0, 10).write + .option("orc.compress", "ZLIB") + .orc(file.getCanonicalPath) + val expectedCompressionKind = + OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + assert("ZLIB" === expectedCompressionKind.name()) + } + + // `compression` overrides `orc.compress`. + withTempPath { file => + spark.range(0, 10).write + .option("compression", "ZLIB") + .option("orc.compress", "SNAPPY") + .orc(file.getCanonicalPath) + val expectedCompressionKind = + OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + assert("ZLIB" === expectedCompressionKind.name()) + } + } + // Hive supports zlib, snappy and none for Hive 1.2.1. test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") { withTempPath { file => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 31b6197d56fc..e92bbdea75a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -589,6 +589,13 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } } + + test("self-join") { + val table = spark.table("normal_parquet") + val selfJoin = table.as("t1").join(table.as("t2")) + checkAnswer(selfJoin, + sql("SELECT * FROM normal_parquet x JOIN normal_parquet y")) + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 047b08c4ccf6..97f2b23823d7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -337,9 +337,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { - Seq.empty[(Int, String)].toDF().createOrReplaceTempView("t") - - withTempView("t") { + withTable("t") { + sql("CREATE TABLE t(i INT) USING parquet") intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") } @@ -347,9 +346,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } test("saveAsTable()/load() - non-partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().createOrReplaceTempView("t") - - withTempView("t") { + withTable("t") { + sql("CREATE TABLE t(i INT) USING parquet") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(spark.table("t").collect().isEmpty) } diff --git a/streaming/build.gradle b/streaming/build.gradle new file mode 100644 index 000000000000..bcfeecb82f8f --- /dev/null +++ b/streaming/build.gradle @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Streaming' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + compile group: 'org.eclipse.jetty', name: 'jetty-server', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-plus', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-util', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-http', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-servlet', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-servlets', version: jettyVersion + + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') + testCompile(group: 'org.seleniumhq.selenium', name: 'selenium-java', version: seleniumVersion) { + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'io.netty', module: 'netty') + } + testCompile(group: 'org.seleniumhq.selenium', name: 'selenium-htmlunit-driver', version: seleniumVersion) { + exclude(group: 'com.google.guava', module: 'guava') + } +} + +// fix scala+java mix to use scala first for tests +sourceSets.test.scala.srcDir 'src/test/java' +sourceSets.test.java.srcDirs = [] diff --git a/streaming/pom.xml b/streaming/pom.xml index e7614fe66110..4e0f5b1013f9 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 9697437dd2fe..0b306a28d1a5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -87,11 +87,11 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { // Gauge for last received batch, useful for monitoring the streaming job's running status, // displayed data -1 for any abnormal condition. registerGaugeWithOption("lastReceivedBatch_submissionTime", - _.lastCompletedBatch.map(_.submissionTime), -1L) + _.lastReceivedBatch.map(_.submissionTime), -1L) registerGaugeWithOption("lastReceivedBatch_processingStartTime", - _.lastCompletedBatch.flatMap(_.processingStartTime), -1L) + _.lastReceivedBatch.flatMap(_.processingStartTime), -1L) registerGaugeWithOption("lastReceivedBatch_processingEndTime", - _.lastCompletedBatch.flatMap(_.processingEndTime), -1L) + _.lastReceivedBatch.flatMap(_.processingEndTime), -1L) // Gauge for last received batch records. registerGauge("lastReceivedBatch_records", _.lastReceivedBatchRecords.values.sum, 0L) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index aeff4d7a98e7..46bfc6085645 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -24,11 +24,14 @@ import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ import scala.language.existentials +import py4j.Py4JException + import org.apache.spark.SparkException import org.apache.spark.api.java._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Duration, Interval, Time} +import org.apache.spark.streaming.{Duration, Interval, StreamingContext, Time} import org.apache.spark.streaming.api.java._ import org.apache.spark.streaming.dstream._ import org.apache.spark.util.Utils @@ -157,7 +160,7 @@ private[python] object PythonTransformFunctionSerializer { /** * Helper functions, which are called from Python via Py4J. */ -private[python] object PythonDStream { +private[streaming] object PythonDStream { /** * can not access PythonTransformFunctionSerializer.register() via Py4j @@ -184,6 +187,32 @@ private[python] object PythonDStream { rdds.asScala.foreach(queue.add) queue } + + /** + * Stop [[StreamingContext]] if the Python process crashes (E.g., OOM) in case the user cannot + * stop it in the Python side. + */ + def stopStreamingContextIfPythonProcessIsDead(e: Throwable): Unit = { + // These two special messages are from: + // scalastyle:off + // https://github.com/bartdag/py4j/blob/5cbb15a21f857e8cf334ce5f675f5543472f72eb/py4j-java/src/main/java/py4j/CallbackClient.java#L218 + // https://github.com/bartdag/py4j/blob/5cbb15a21f857e8cf334ce5f675f5543472f72eb/py4j-java/src/main/java/py4j/CallbackClient.java#L340 + // scalastyle:on + if (e.isInstanceOf[Py4JException] && + ("Cannot obtain a new communication channel" == e.getMessage || + "Error while obtaining a new communication channel" == e.getMessage)) { + // Start a new thread to stop StreamingContext to avoid deadlock. + new Thread("Stop-StreamingContext") with Logging { + setDaemon(true) + + override def run(): Unit = { + logError( + "Cannot connect to Python process. It's probably dead. Stopping StreamingContext.", e) + StreamingContext.getActive().foreach(_.stop(stopSparkContext = false)) + } + }.start() + } + } } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index fa15a0bf65ab..f3080c730fe0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -14,13 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.streaming.dstream import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.util.concurrent.ConcurrentHashMap -import scala.collection.mutable.HashMap import scala.language.implicitConversions import scala.reflect.ClassTag import scala.util.matching.Regex @@ -81,9 +99,17 @@ abstract class DStream[T: ClassTag] ( // Methods and fields available on all DStreams // ======================================================================= + import scala.collection.JavaConverters._ // RDDs generated, marked as private[streaming] so that testsuites can access it @transient - private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]() + // private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]() + private[streaming] var generatedRDDs: scala.collection.mutable.Map[Time, RDD[T]] = _ + + initGeneratedRDDs() + + def initGeneratedRDDs(): Unit = { + generatedRDDs = new ConcurrentHashMap[Time, RDD[T]]().asScala + } // Time zero for the DStream private[streaming] var zeroTime: Time = null @@ -189,6 +215,18 @@ abstract class DStream[T: ClassTag] ( * its parent DStreams. */ private[streaming] def initialize(time: Time) { + initialize(time, skipInitialized = false) + } + + /** + * Initialize the DStream by setting the "zero" time, based on which + * the validity of future times is calculated. This method also recursively initializes + * its parent DStreams. + */ + private[streaming] def initialize(time: Time, skipInitialized: Boolean) { + if (skipInitialized && isInitialized) { + return + } if (zeroTime != null && zeroTime != time) { throw new SparkException(s"ZeroTime is already initialized to $zeroTime" + s", cannot initialize it again to $time") @@ -212,7 +250,7 @@ abstract class DStream[T: ClassTag] ( } // Initialize the dependencies - dependencies.foreach(_.initialize(zeroTime)) + dependencies.foreach(_.initialize(zeroTime, skipInitialized)) } private def validateAtInit(): Unit = { @@ -220,9 +258,11 @@ abstract class DStream[T: ClassTag] ( case StreamingContextState.INITIALIZED => // good to go case StreamingContextState.ACTIVE => + /* throw new IllegalStateException( "Adding new inputs, transformations, and output operations after " + "starting a context is not supported") + */ case StreamingContextState.STOPPED => throw new IllegalStateException( "Adding new inputs, transformations, and output operations after " + @@ -534,7 +574,8 @@ abstract class DStream[T: ClassTag] ( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(s"${this.getClass().getSimpleName}.readObject used") ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[T]]() + // generatedRDDs = new HashMap[Time, RDD[T]]() + initGeneratedRDDs() } // ======================================================================= @@ -650,8 +691,12 @@ abstract class DStream[T: ClassTag] ( private def foreachRDD( foreachFunc: (RDD[T], Time) => Unit, displayInnerRDDOps: Boolean): Unit = { - new ForEachDStream(this, - context.sparkContext.clean(foreachFunc, false), displayInnerRDDOps).register() + val dStream = new ForEachDStream(this, + context.sparkContext.clean(foreachFunc, false), displayInnerRDDOps) + if (ssc.getState() == StreamingContextState.ACTIVE) { + dStream.initialize(ssc.graph.zeroTime, skipInitialized = true) + } + dStream.register() } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index ed9305875cb7..bf4be782bc00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -14,6 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ package org.apache.spark.streaming.dstream @@ -315,7 +333,8 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() + // generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() + initGeneratedRDDs() batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() fileToModTime = new TimeStampedHashMap[String, Long](true) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 53fccd8d5e6e..c68dbb72ab67 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -14,11 +14,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +/* + * Changes for SnappyData data platform. + * + * Portions Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + package org.apache.spark.streaming.rdd -import java.io.File import java.nio.ByteBuffer -import java.util.UUID import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -27,7 +44,7 @@ import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.io.ChunkedByteBuffer /** @@ -120,7 +137,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( val blockId = partition.blockId def getBlockFromBlockManager(): Option[Iterator[T]] = { - blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]]) + blockManager.get[T](blockId).map(_.data.asInstanceOf[Iterator[T]]) } def getBlockFromWriteAheadLog(): Iterator[T] = { @@ -135,8 +152,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // FileBasedWriteAheadLog will not create any file or directory at that path. Also, // this dummy directory should not already exist otherwise the WAL will try to recover // past events from the directory and throw errors. - val nonExistentDirectory = new File( - System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString).getAbsolutePath + val nonExistentDirectory = Utils.tempFileWith( + System.getProperty("java.io.tmpdir"), prefix = null).getAbsolutePath writeAheadLog = WriteAheadLogUtils.createLogForReceiver( SparkEnv.get.conf, nonExistentDirectory, hadoopConf) dataRead = writeAheadLog.read(partition.walRecordHandle) @@ -163,7 +180,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( dataRead.rewind() } serializerManager - .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream()) + .dataDeserializeStream( + blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 19c88f1ee011..4489a5334d17 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -22,6 +22,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} +import org.apache.spark.streaming.api.python.PythonDStream import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} @@ -252,6 +253,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + PythonDStream.stopStreamingContextIfPythonProcessIsDead(e) } eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 79d6254eb372..f5ba5edad9e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark.internal.Logging import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ +import org.apache.spark.streaming.api.python.PythonDStream import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.{EventLoop, ThreadUtils} @@ -210,6 +211,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleError(msg: String, e: Throwable) { logError(msg, e) ssc.waiter.notifyError(e) + PythonDStream.stopStreamingContextIfPythonProcessIsDead(e) } private class JobHandler(job: Job) extends Runnable with Logging { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index c086df47d983..61f852a0d31a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -259,7 +259,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) // We use an Iterable rather than explicitly converting to a seq so that updates // will propagate val outputOpIdToSparkJobIds: Iterable[OutputOpIdAndSparkJobId] = - Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime).asScala) + Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime)).map(_.asScala) .getOrElse(Seq.empty) _batchUIData.outputOpIdSparkJobIdPairs = outputOpIdToSparkJobIds } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index e97427991bf9..7e665454a540 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -23,6 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps +import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, Matchers} @@ -47,6 +48,7 @@ class ReceivedBlockHandlerSuite extends SparkFunSuite with BeforeAndAfter with Matchers + with LocalSparkContext with Logging { import WriteAheadLogBasedBlockHandler._ @@ -77,8 +79,10 @@ class ReceivedBlockHandlerSuite rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) + sc = new SparkContext("local", "test", conf) blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(sc))), conf, true) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf) @@ -160,7 +164,7 @@ class ReceivedBlockHandlerSuite val bytes = reader.read(fileSegment) reader.close() serializerManager.dataDeserializeStream( - generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList + generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList } loggedData shouldEqual data } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 26b757cc2d53..46ab3ac8de3d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -68,6 +68,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) listener.retainedCompletedBatches should be (Nil) + listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoSubmitted))) listener.lastCompletedBatch should be (None) listener.numUnprocessedBatches should be (1) listener.numTotalCompletedBatches should be (0) @@ -81,6 +82,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) listener.retainedCompletedBatches should be (Nil) + listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoStarted))) listener.lastCompletedBatch should be (None) listener.numUnprocessedBatches should be (1) listener.numTotalCompletedBatches should be (0) @@ -123,6 +125,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) listener.retainedCompletedBatches should be (List(BatchUIData(batchInfoCompleted))) + listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoCompleted))) listener.lastCompletedBatch should be (Some(BatchUIData(batchInfoCompleted))) listener.numUnprocessedBatches should be (0) listener.numTotalCompletedBatches should be (1) diff --git a/tools/build.gradle b/tools/build.gradle new file mode 100644 index 000000000000..05b48719a0d3 --- /dev/null +++ b/tools/build.gradle @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project Tools' + +dependencies { + compile group: 'org.scala-lang', name: 'scala-compiler', version: scalaVersion + compile group: 'org.clapper', name: 'classutil_' + scalaBinaryVersion, version: '1.0.12' +} + +// TODO: anything special required for deploy, install and source plugins in maven? diff --git a/tools/pom.xml b/tools/pom.xml index 3f4cce1ca354..ddc1b091fec8 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/yarn/build.gradle b/yarn/build.gradle new file mode 100644 index 000000000000..5fb389cba38c --- /dev/null +++ b/yarn/build.gradle @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2016 SnappyData, Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. See accompanying + * LICENSE file. + */ + +description = 'Spark Project YARN' + +dependencies { + compile project(subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion) + compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) + + compile(group: 'org.apache.hadoop', name: 'hadoop-yarn-api', version: hadoopVersion) { + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + } + compile(group: 'org.apache.hadoop', name: 'hadoop-yarn-common', version: hadoopVersion) { + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + } + compile(group: 'org.apache.hadoop', name: 'hadoop-yarn-server-web-proxy', version: hadoopVersion) { + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + } + compile(group: 'org.apache.hadoop', name: 'hadoop-yarn-client', version: hadoopVersion) { + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + } + compile group: 'com.google.guava', name: 'guava', version: guavaVersion + compile group: 'org.eclipse.jetty', name: 'jetty-server', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-plus', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-util', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-http', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-servlet', version: jettyVersion + compile group: 'org.eclipse.jetty', name: 'jetty-servlets', version: jettyVersion + compile group: 'org.apache.derby', name: 'derby', version: derbyVersion + compile(group: 'org.spark-project.hive', name: 'hive-exec', version: hiveVersion) { + exclude(group: 'org.spark-project.hive', module: 'hive-metastore') + exclude(group: 'org.spark-project.hive', module: 'hive-shims') + exclude(group: 'org.spark-project.hive', module: 'hive-ant') + exclude(group: 'org.spark-project.hive', module: 'spark-client') + exclude(group: 'org.apache.ant', module: 'ant') + exclude(group: 'com.esotericsoftware.kryo', module: 'kryo') + exclude(group: 'commons-codec', module: 'commons-codec') + exclude(group: 'commons-httpclient', module: 'commons-httpclient') + exclude(group: 'org.apache.avro', module: 'avro-mapred') + exclude(group: 'org.apache.calcite', module: 'calcite-core') + exclude(group: 'org.apache.curator', module: 'apache-curator') + exclude(group: 'org.apache.curator', module: 'curator-client') + exclude(group: 'org.apache.curator', module: 'curator-framework') + exclude(group: 'org.apache.thrift', module: 'libthrift') + exclude(group: 'org.apache.thrift', module: 'libfb303') + exclude(group: 'org.apache.zookeeper', module: 'zookeeper') + exclude(group: 'org.slf4j', module: 'slf4j-api') + exclude(group: 'org.slf4j', module: 'slf4j-log4j12') + exclude(group: 'log4j', module: 'log4j') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'org.codehaus.groovy', module: 'groovy-all') + exclude(group: 'jline', module: 'jline') + } + compile(group: 'org.spark-project.hive', name: 'hive-metastore', version: hiveVersion) { + exclude(group: 'org.spark-project.hive', module: 'hive-serde') + exclude(group: 'org.spark-project.hive', module: 'hive-shims') + exclude(group: 'org.apache.thrift', module: 'libfb303') + exclude(group: 'org.apache.thrift', module: 'libthrift') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'com.google.guava', module: 'guava') + exclude(group: 'org.slf4j', module: 'slf4j-api') + exclude(group: 'org.slf4j', module: 'slf4j-log4j12') + exclude(group: 'org.apache.derby', module: 'derby') + } + compile(group: 'org.apache.thrift', name: 'libthrift', version: thriftVersion) { + exclude(group: 'org.slf4j', module: 'slf4j-api') + } + compile(group: 'org.apache.thrift', name: 'libfb303', version: thriftVersion) { + exclude(group: 'org.slf4j', module: 'slf4j-api') + } + + testCompile project(subprojectBase + 'snappy-spark-network-yarn_' + scalaBinaryVersion) + testCompile project(path: subprojectBase + 'snappy-spark-core_' + scalaBinaryVersion, configuration: 'testOutput') + + testCompile group: 'org.eclipse.jetty.orbit', name: 'javax.servlet.jsp', version: '2.2.0.v201112011158' + testCompile group: 'org.eclipse.jetty.orbit', name: 'javax.servlet.jsp.jstl', version: '1.2.0.v201105211821' + testCompile(group: 'org.apache.hadoop', name: 'hadoop-yarn-server-tests', version: hadoopVersion, classifier:'tests') { + exclude(group: 'asm', module: 'asm') + exclude(group: 'org.ow2.asm', module: 'asm') + exclude(group: 'org.jboss.netty', module: 'netty') + exclude(group: 'javax.servlet', module: 'servlet-api') + exclude(group: 'commons-logging', module: 'commons-logging') + exclude(group: 'com.sun.jersey') + exclude(group: 'com.sun.jersey.jersey-test-framework') + exclude(group: 'com.sun.jersey.contribs') + } + testCompile(group: 'org.mortbay.jetty', name: 'jetty', version: '6.1.26') { + exclude(group: 'org.mortbay.jetty', module: 'servlet-api') + } + testCompile group: 'com.sun.jersey', name: 'jersey-core', version: sunJerseyVersion + testCompile group: 'com.sun.jersey', name: 'jersey-server', version: sunJerseyVersion + testCompile(group: 'com.sun.jersey', name: 'jersey-json', version: sunJerseyVersion) { + exclude(group: 'stax', module: 'stax-api') + } + testCompile group: 'com.sun.jersey.contribs', name: 'jersey-guice', version: sunJerseyVersion +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 7dba1a829fb9..b676b6d21a2c 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.1-SNAPSHOT + 2.0.1 ../pom.xml diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 01aa12a3c9a7..a47a64cccfa5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1148,10 +1148,10 @@ private[spark] class Client( val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator) val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), - "pyspark.zip not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.1-src.zip") + s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") + val py4jFile = new File(pyLibPath, "py4j-0.10.3-src.zip") require(py4jFile.exists(), - "py4j-0.10.1-src.zip not found; cannot run pyspark application in YARN mode.") + s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 1b800716495a..b321901e765e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -483,7 +483,6 @@ private[yarn] class YarnAllocator( def updateInternalState(): Unit = synchronized { numExecutorsRunning += 1 - assert(numExecutorsRunning <= targetNumExecutors) executorIdToContainer(executorId) = container containerIdToExecutorId(container.getId) = executorId @@ -493,39 +492,44 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.put(containerId, executorHostname) } - if (launchContainers) { - logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( - driverUrl, executorHostname)) - - launcherPool.execute(new Runnable { - override def run(): Unit = { - try { - new ExecutorRunnable( - container, - conf, - sparkConf, - driverUrl, - executorId, - executorHostname, - executorMemory, - executorCores, - appAttemptId.getApplicationId.toString, - securityMgr, - localResources - ).run() - updateInternalState() - } catch { - case NonFatal(e) => - logError(s"Failed to launch executor $executorId on container $containerId", e) - // Assigned container should be released immediately to avoid unnecessary resource - // occupation. - amClient.releaseAssignedContainer(containerId) + if (numExecutorsRunning < targetNumExecutors) { + if (launchContainers) { + logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( + driverUrl, executorHostname)) + + launcherPool.execute(new Runnable { + override def run(): Unit = { + try { + new ExecutorRunnable( + container, + conf, + sparkConf, + driverUrl, + executorId, + executorHostname, + executorMemory, + executorCores, + appAttemptId.getApplicationId.toString, + securityMgr, + localResources + ).run() + updateInternalState() + } catch { + case NonFatal(e) => + logError(s"Failed to launch executor $executorId on container $containerId", e) + // Assigned container should be released immediately to avoid unnecessary resource + // occupation. + amClient.releaseAssignedContainer(containerId) + } } - } - }) + }) + } else { + // For test only + updateInternalState() + } } else { - // For test only - updateInternalState() + logInfo(("Skip launching executorRunnable as runnning Excecutors count: %d " + + "reached target Executors count: %d.").format(numExecutorsRunning, targetNumExecutors)) } } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 6b3c831e6047..2f9ea1911fd6 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} @@ -124,16 +125,16 @@ private[spark] abstract class YarnSchedulerBackend( * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. */ - override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpointRef.askWithRetry[Boolean]( + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { + yarnSchedulerEndpointRef.ask[Boolean]( RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } /** * Request that the ApplicationMaster kill the specified executors. */ - override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { + yarnSchedulerEndpointRef.ask[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -211,35 +212,35 @@ private[spark] abstract class YarnSchedulerBackend( extends ThreadSafeRpcEndpoint with Logging { private var amEndpoint: Option[RpcEndpointRef] = None - private val askAmThreadPool = - ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") - implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) - private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( executorId: String, executorRpcAddress: RpcAddress): Unit = { - amEndpoint match { + val removeExecutorMessage = amEndpoint match { case Some(am) => val lossReasonRequest = GetExecutorLossReason(executorId) - val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) - future onSuccess { - case reason: ExecutorLossReason => - driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) - } - future onFailure { - case NonFatal(e) => - logWarning(s"Attempted to get executor loss reason" + - s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + - s" but got no response. Marking as slave lost.", e) - driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) - case t => throw t - } + am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) + .map { reason => RemoveExecutor(executorId, reason) }(ThreadUtils.sameThread) + .recover { + case NonFatal(e) => + logWarning(s"Attempted to get executor loss reason" + + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + + s" but got no response. Marking as slave lost.", e) + RemoveExecutor(executorId, SlaveLost()) + }(ThreadUtils.sameThread) case None => logWarning("Attempted to check for an executor loss reason" + " before the AM has registered!") - driverEndpoint.askWithRetry[Boolean]( - RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) + Future.successful(RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) } + + removeExecutorMessage + .flatMap { message => + driverEndpoint.ask[Boolean](message) + }(ThreadUtils.sameThread) + .onFailure { + case NonFatal(e) => logError( + s"Error requesting driver to remove executor $executorId after disconnection.", e) + }(ThreadUtils.sameThread) } override def receive: PartialFunction[Any, Unit] = { @@ -257,9 +258,13 @@ private[spark] abstract class YarnSchedulerBackend( case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) - case RemoveExecutor(executorId, reason) => + case r @ RemoveExecutor(executorId, reason) => logWarning(reason.toString) - removeExecutor(executorId, reason) + driverEndpoint.ask[Boolean](r).onFailure { + case e => + logError("Error requesting driver to remove executor" + + s" $executorId for reason $reason", e) + }(ThreadUtils.sameThread) } @@ -267,13 +272,12 @@ private[spark] abstract class YarnSchedulerBackend( case r: RequestExecutors => amEndpoint match { case Some(am) => - Future { - context.reply(am.askWithRetry[Boolean](r)) - } onFailure { - case NonFatal(e) => + am.ask[Boolean](r).andThen { + case Success(b) => context.reply(b) + case Failure(NonFatal(e)) => logError(s"Sending $r to AM was unsuccessful", e) context.sendFailure(e) - } + }(ThreadUtils.sameThread) case None => logWarning("Attempted to request executors before the AM has registered!") context.reply(false) @@ -282,13 +286,12 @@ private[spark] abstract class YarnSchedulerBackend( case k: KillExecutors => amEndpoint match { case Some(am) => - Future { - context.reply(am.askWithRetry[Boolean](k)) - } onFailure { - case NonFatal(e) => + am.ask[Boolean](k).andThen { + case Success(b) => context.reply(b) + case Failure(NonFatal(e)) => logError(s"Sending $k to AM was unsuccessful", e) context.sendFailure(e) - } + }(ThreadUtils.sameThread) case None => logWarning("Attempted to kill executors before the AM has registered!") context.reply(false) @@ -304,10 +307,6 @@ private[spark] abstract class YarnSchedulerBackend( amEndpoint = None } } - - override def onStop(): Unit = { - askAmThreadPool.shutdownNow() - } } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 207dbf56d360..f8351c03e561 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -136,6 +136,25 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter size should be (0) } + test("container should not be created if requested number if met") { + // request a single container and receive it + val handler = createAllocator(1) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (1) + + val container = createContainer("host1") + handler.handleAllocatedContainers(Array(container)) + + handler.getNumExecutorsRunning should be (1) + handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") + handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) + + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container2)) + handler.getNumExecutorsRunning should be (1) + } + test("some containers allocated") { // request a few containers and receive some of them val handler = createAllocator(4) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 1ccd7e5993f5..34f9aa64518c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -216,7 +216,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.1-src.zip", + s"$sparkHome/python/lib/py4j-0.10.3-src.zip", s"$sparkHome/python") val extraEnv = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index c33a9e6bbe25..0c2e0204356b 100644 --- a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext} import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.apache.spark.SecurityManager import org.apache.spark.SparkFunSuite import org.apache.spark.network.shuffle.ShuffleTestAccessor import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -77,6 +78,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd test("executor state kept across NM restart") { s1 = new YarnShuffleService + // set auth to true to test the secrets recovery + yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, true) s1.init(yarnConfig) val app1Id = ApplicationId.newInstance(0, 1) val app1Data: ApplicationInitializationContext = @@ -89,6 +92,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd val execStateFile = s1.registeredExecutorFile execStateFile should not be (null) + val secretsFile = s1.secretsFile + secretsFile should not be (null) val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, SORT_MANAGER) val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, SORT_MANAGER) @@ -118,6 +123,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd s1.stop() s2 = new YarnShuffleService s2.init(yarnConfig) + s2.secretsFile should be (secretsFile) s2.registeredExecutorFile should be (execStateFile) val handler2 = s2.blockHandler @@ -135,6 +141,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd s3 = new YarnShuffleService s3.init(yarnConfig) s3.registeredExecutorFile should be (execStateFile) + s3.secretsFile should be (secretsFile) val handler3 = s3.blockHandler val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) @@ -148,7 +155,10 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd test("removed applications should not be in registered executor file") { s1 = new YarnShuffleService + yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, false) s1.init(yarnConfig) + val secretsFile = s1.secretsFile + secretsFile should be (null) val app1Id = ApplicationId.newInstance(0, 1) val app1Data: ApplicationInitializationContext = new ApplicationInitializationContext("user", app1Id, null)